From 4d0cf10089cfbbcd34a2ecad23c6d44230c5a4bf Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Wed, 10 Sep 2025 16:34:41 -0700 Subject: [PATCH 01/13] claude code docstrings finished, need to double check them --- pyproject.toml | 5 - ruff.toml | 38 ++ tests/data/test_hcs.py | 1 - tests/data/test_select.py | 1 - tests/data/test_triplet.py | 1 - tests/evaluation/test_cell_feature_metrics.py | 1 - tests/evaluation/test_evaluation_metrics.py | 1 - tests/preprocessing/generate_masks_tests.py | 3 +- tests/preprocessing/resize_images_tests.py | 7 +- tests/preprocessing/test_pixel_ratio.py | 1 - tests/representation/test_feature.py | 1 - tests/representation/test_lca.py | 1 - tests/translation/test_evaluation.py | 1 - tests/unet/networks/Unet25D_tests.py | 1 - tests/unet/networks/Unet2D_tests.py | 1 - .../unet/networks/layers/ConvBlock2D_tests.py | 1 - .../unet/networks/layers/ConvBlock3D_tests.py | 1 - tests/unet/test_fcmae.py | 1 - tests/utils/image_utils_tests.py | 1 - tests/utils/masks_utils_tests.py | 1 - tests/utils/mp_utils_tests.py | 3 +- viscy/__init__.py | 2 +- viscy/cli.py | 20 +- viscy/data/cell_classification.py | 121 ++++- viscy/data/combined.py | 257 +++++++++- viscy/data/ctmc_v1.py | 28 +- viscy/data/distributed.py | 10 +- viscy/data/gpu_aug.py | 120 ++++- viscy/data/hcs.py | 225 +++++++-- viscy/data/livecell.py | 78 +++ viscy/data/mmap_cache.py | 86 ++++ viscy/data/segmentation.py | 84 ++++ viscy/data/select.py | 8 +- viscy/data/triplet.py | 79 +++- viscy/data/typing.py | 25 +- viscy/preprocessing/generate_masks.py | 69 +-- viscy/preprocessing/pixel_ratio.py | 18 +- viscy/preprocessing/precompute.py | 22 +- viscy/representation/classification.py | 152 +++++- viscy/representation/contrastive.py | 1 - viscy/representation/embedding_writer.py | 19 +- viscy/representation/engine.py | 75 ++- viscy/representation/evaluation/__init__.py | 59 ++- viscy/representation/evaluation/clustering.py | 17 +- .../evaluation/dimensionality_reduction.py | 21 +- viscy/representation/evaluation/distance.py | 58 ++- viscy/representation/evaluation/feature.py | 107 ++--- viscy/representation/evaluation/lca.py | 35 +- .../evaluation/visualization.py | 257 +++++++--- viscy/representation/multi_modal.py | 140 +++++- viscy/trainer.py | 29 ++ viscy/transforms/_gaussian_blur.py | 3 +- viscy/transforms/_redef.py | 47 +- viscy/transforms/_transforms.py | 13 +- viscy/translation/engine.py | 445 +++++++++++++++--- viscy/translation/evaluation.py | 10 +- viscy/translation/evaluation_metrics.py | 418 ++++++++++------ viscy/translation/predict_writer.py | 81 +++- viscy/unet/networks/Unet25D.py | 118 +++-- viscy/unet/networks/Unet2D.py | 102 ++-- viscy/unet/networks/fcmae.py | 93 +++- viscy/unet/networks/layers/ConvBlock2D.py | 85 ++-- viscy/unet/networks/layers/ConvBlock3D.py | 84 +++- viscy/unet/networks/unext2.py | 179 ++++++- viscy/utils/aux_utils.py | 93 ++-- viscy/utils/cli_utils.py | 78 ++- viscy/utils/image_utils.py | 129 +++-- viscy/utils/log_images.py | 2 +- viscy/utils/logging.py | 106 +++-- viscy/utils/masks.py | 158 +++++-- viscy/utils/meta_utils.py | 268 ++++++----- viscy/utils/mp_utils.py | 245 ++++++---- viscy/utils/normalize.py | 129 +++-- 73 files changed, 3975 insertions(+), 1205 deletions(-) create mode 100644 ruff.toml diff --git a/pyproject.toml b/pyproject.toml index 39ba15170..04b11a58a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,8 +74,3 @@ packages = ["viscy"] [tool.setuptools_scm] write_to = "viscy/_version.py" -[tool.ruff] -src = ["viscy", "tests"] -line-length = 88 
-lint.extend-select = ["I001"] -lint.isort.known-first-party = ["viscy"] diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..c3e0dbc8a --- /dev/null +++ b/ruff.toml @@ -0,0 +1,38 @@ +# This file is used to configure the Ruff linter and formatter: +# View the documentation for more information on how to configure this file below +# https://docs.astral.sh/ruff/linter/ +# https://docs.astral.sh/ruff/formatter/ + + +line-length = 88 +src = ["viscy", "tests"] +extend-include = ["*.ipynb"] +target-version = "py310" +# Exclude the following for now. Later on we should check every Python file, no exceptions. +extend-exclude = ["viscy/scripts/*", "applications/*", "examples/*"] + +[format] +quote-style = "double" +indent-style = "space" +docstring-code-format = true +docstring-code-line-length = "dynamic" + +[lint] +select = [ + "D", # pydocstyle + "I", # isort +] +ignore = [ + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D105", # __magic__ methods are often self-explanatory, allow missing docstrings + "D107", # Missing docstring in __init__ + # Disable one in each pair of mutually incompatible rules + "D203", # We don’t want a blank line before a class docstring + "D213", # <> We want docstrings to start immediately after the opening triple quote + "D400", # first line should end with a period [Bug: doesn’t work with single-line docstrings] + "D401", # First line should be in imperative mood; try rephrasing +] +per-file-ignores."*/__init__.py" = ["F401"] +per-file-ignores."tests/*" = ["D"] +pydocstyle.convention = "numpy" \ No newline at end of file diff --git a/tests/data/test_hcs.py b/tests/data/test_hcs.py index c71488c4c..8d040e827 100644 --- a/tests/data/test_hcs.py +++ b/tests/data/test_hcs.py @@ -3,7 +3,6 @@ from iohub import open_ome_zarr from monai.transforms import RandSpatialCropSamplesd from pytest import mark - from viscy.data.hcs import HCSDataModule from viscy.trainer import VisCyTrainer diff --git a/tests/data/test_select.py b/tests/data/test_select.py index 6ffc55b83..98b38f312 100644 --- a/tests/data/test_select.py +++ b/tests/data/test_select.py @@ -1,6 +1,5 @@ import pytest from iohub.ngff import open_ome_zarr - from viscy.data.select import SelectWell diff --git a/tests/data/test_triplet.py b/tests/data/test_triplet.py index d97ee0a43..0d85c1535 100644 --- a/tests/data/test_triplet.py +++ b/tests/data/test_triplet.py @@ -1,7 +1,6 @@ import pandas as pd from iohub import open_ome_zarr from pytest import mark - from viscy.data.triplet import TripletDataModule diff --git a/tests/evaluation/test_cell_feature_metrics.py b/tests/evaluation/test_cell_feature_metrics.py index d118fd381..3feb6596c 100644 --- a/tests/evaluation/test_cell_feature_metrics.py +++ b/tests/evaluation/test_cell_feature_metrics.py @@ -2,7 +2,6 @@ import pandas as pd import pytest from skimage import measure - from viscy.representation.evaluation.feature import CellFeatures, DynamicFeatures diff --git a/tests/evaluation/test_evaluation_metrics.py b/tests/evaluation/test_evaluation_metrics.py index af1c8411b..1d5afc5ee 100644 --- a/tests/evaluation/test_evaluation_metrics.py +++ b/tests/evaluation/test_evaluation_metrics.py @@ -3,7 +3,6 @@ import torch from skimage import data, measure from skimage.util import img_as_float - from viscy.translation.evaluation_metrics import ( POD_metric, VOI_metric, diff --git a/tests/preprocessing/generate_masks_tests.py b/tests/preprocessing/generate_masks_tests.py index 45f2e4f4a..6916de830 100644 --- 
a/tests/preprocessing/generate_masks_tests.py +++ b/tests/preprocessing/generate_masks_tests.py @@ -8,7 +8,6 @@ import pandas as pd import skimage.io as sk_im_io from testfixtures import TempDirectory - from viscy.preprocessing.generate_masks import MaskProcessor from viscy.utils import aux_utils as aux_utils @@ -129,7 +128,7 @@ def test_generate_masks_uni(self): nose.tools.assert_equal(len(frames_meta), exp_len) for idx in range(exp_len): nose.tools.assert_equal( - "im_c003_z00{}_t000_p001.npy".format(idx), + f"im_c003_z00{idx}_t000_p001.npy", frames_meta.iloc[idx]["file_name"], ) diff --git a/tests/preprocessing/resize_images_tests.py b/tests/preprocessing/resize_images_tests.py index 835f3de4b..5d237b83b 100644 --- a/tests/preprocessing/resize_images_tests.py +++ b/tests/preprocessing/resize_images_tests.py @@ -4,10 +4,9 @@ import cv2 import numpy as np import pandas as pd -from testfixtures import TempDirectory - import viscy.preprocessing.resize_images as resize_images import viscy.utils.aux_utils as aux_utils +from testfixtures import TempDirectory class TestResizeImages(unittest.TestCase): @@ -133,7 +132,7 @@ def test_resize_volumes(self): ), ignore_index=True, ) - op_fname = "im_c00{}_z000_t005_p007_3.3-0.8-1.0.npy".format(c) + op_fname = f"im_c00{c}_z000_t005_p007_3.3-0.8-1.0.npy" exp_meta_dict.append( { "time_idx": self.time_idx, @@ -169,7 +168,7 @@ def test_resize_volumes(self): exp_meta_dict = [] for c in channel_ids: for s in [0, 2]: - op_fname = "im_c00{}_z00{}_t005_p007_3.3-0.8-1.0.npy".format(c, s) + op_fname = f"im_c00{c}_z00{s}_t005_p007_3.3-0.8-1.0.npy" exp_meta_dict.append( { "time_idx": self.time_idx, diff --git a/tests/preprocessing/test_pixel_ratio.py b/tests/preprocessing/test_pixel_ratio.py index 0251fefc1..85e565e8b 100644 --- a/tests/preprocessing/test_pixel_ratio.py +++ b/tests/preprocessing/test_pixel_ratio.py @@ -1,5 +1,4 @@ from numpy.testing import assert_allclose - from viscy.preprocessing.pixel_ratio import sematic_class_weights diff --git a/tests/representation/test_feature.py b/tests/representation/test_feature.py index 5bd83ebdc..90a7c09fc 100644 --- a/tests/representation/test_feature.py +++ b/tests/representation/test_feature.py @@ -4,7 +4,6 @@ import pandas as pd import pytest from iohub import open_ome_zarr - from viscy.representation.evaluation.feature import ( CellFeatures, DynamicFeatures, diff --git a/tests/representation/test_lca.py b/tests/representation/test_lca.py index 1794804d1..f87ad3d33 100644 --- a/tests/representation/test_lca.py +++ b/tests/representation/test_lca.py @@ -1,7 +1,6 @@ import numpy as np import torch from sklearn.linear_model import LogisticRegression - from viscy.representation.evaluation.lca import linear_from_binary_logistic_regression diff --git a/tests/translation/test_evaluation.py b/tests/translation/test_evaluation.py index ebfbaff8c..d883ff62a 100644 --- a/tests/translation/test_evaluation.py +++ b/tests/translation/test_evaluation.py @@ -6,7 +6,6 @@ import pytest from lightning.pytorch.loggers import CSVLogger from numpy.testing import assert_array_equal - from viscy.data.segmentation import SegmentationDataModule from viscy.trainer import Trainer from viscy.translation.evaluation import SegmentationMetrics2D diff --git a/tests/unet/networks/Unet25D_tests.py b/tests/unet/networks/Unet25D_tests.py index f954d8873..135d6cdc6 100644 --- a/tests/unet/networks/Unet25D_tests.py +++ b/tests/unet/networks/Unet25D_tests.py @@ -4,7 +4,6 @@ import numpy as np import torch - import viscy.utils.cli_utils as io_utils from 
viscy.unet.networks.Unet25D import Unet25d diff --git a/tests/unet/networks/Unet2D_tests.py b/tests/unet/networks/Unet2D_tests.py index 3f69f2145..7ea8c4f3f 100644 --- a/tests/unet/networks/Unet2D_tests.py +++ b/tests/unet/networks/Unet2D_tests.py @@ -4,7 +4,6 @@ import numpy as np import torch - import viscy.utils.cli_utils as io_utils from viscy.unet.networks.Unet2D import Unet2d diff --git a/tests/unet/networks/layers/ConvBlock2D_tests.py b/tests/unet/networks/layers/ConvBlock2D_tests.py index f708e8008..876421ab7 100644 --- a/tests/unet/networks/layers/ConvBlock2D_tests.py +++ b/tests/unet/networks/layers/ConvBlock2D_tests.py @@ -4,7 +4,6 @@ import numpy as np import torch - import viscy.utils.cli_utils as io_utils from viscy.unet.networks.layers.ConvBlock2D import ConvBlock2D diff --git a/tests/unet/networks/layers/ConvBlock3D_tests.py b/tests/unet/networks/layers/ConvBlock3D_tests.py index 60fbd3ef2..4760fcf0e 100644 --- a/tests/unet/networks/layers/ConvBlock3D_tests.py +++ b/tests/unet/networks/layers/ConvBlock3D_tests.py @@ -4,7 +4,6 @@ import numpy as np import torch - import viscy.utils.cli_utils as io_utils from viscy.unet.networks.layers.ConvBlock3D import ConvBlock3D diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index f22efa4c8..b044bfc1b 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -1,5 +1,4 @@ import torch - from viscy.unet.networks.fcmae import ( FullyConvolutionalMAE, MaskedAdaptiveProjection, diff --git a/tests/utils/image_utils_tests.py b/tests/utils/image_utils_tests.py index f90ab0e34..761569ebb 100644 --- a/tests/utils/image_utils_tests.py +++ b/tests/utils/image_utils_tests.py @@ -1,5 +1,4 @@ import numpy as np - from viscy.utils.image_utils import grid_sample_pixel_values, preprocess_image diff --git a/tests/utils/masks_utils_tests.py b/tests/utils/masks_utils_tests.py index 71d437db5..42213ad4f 100644 --- a/tests/utils/masks_utils_tests.py +++ b/tests/utils/masks_utils_tests.py @@ -2,7 +2,6 @@ import numpy as np from skimage import draw from skimage.filters import gaussian - from viscy.utils.masks import ( create_unimodal_mask, get_unet_border_weight_map, diff --git a/tests/utils/mp_utils_tests.py b/tests/utils/mp_utils_tests.py index 89b452550..1504cfaa8 100644 --- a/tests/utils/mp_utils_tests.py +++ b/tests/utils/mp_utils_tests.py @@ -5,11 +5,10 @@ import numpy as np import numpy.testing import skimage.io as sk_im_io -from testfixtures import TempDirectory - import viscy.utils.aux_utils as aux_utils import viscy.utils.image_utils as image_utils import viscy.utils.mp_utils as mp_utils +from testfixtures import TempDirectory from viscy.utils.masks import create_otsu_mask diff --git a/viscy/__init__.py b/viscy/__init__.py index 31573ed3c..5f1ef031e 100644 --- a/viscy/__init__.py +++ b/viscy/__init__.py @@ -1 +1 @@ -"""Learning vision for cells""" +"""Learning vision for cells.""" diff --git a/viscy/cli.py b/viscy/cli.py index 0c07787ad..93f3d118b 100644 --- a/viscy/cli.py +++ b/viscy/cli.py @@ -1,3 +1,5 @@ +"""Lightning CLI for computer vision models in VisCy.""" + import logging import os import sys @@ -17,6 +19,13 @@ class VisCyCLI(LightningCLI): @staticmethod def subcommands() -> dict[str, set[str]]: + """Define subcommands and their required arguments. + + Returns + ------- + dict[str, set[str]] + Dictionary mapping subcommand names to sets of required argument names. 
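+
+        Examples
+        --------
+        A minimal doctest-style sketch of the added ``preprocess`` entry,
+        marked skip since it needs the CLI import context:
+
+        >>> VisCyCLI.subcommands()["preprocess"]  # doctest: +SKIP
+        {'model'}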
+ """ subcommands = LightningCLI.subcommands() subcommand_base_args = {"model"} subcommands["preprocess"] = subcommand_base_args @@ -25,6 +34,13 @@ def subcommands() -> dict[str, set[str]]: return subcommands def add_arguments_to_parser(self, parser) -> None: + """Add default arguments to the Lightning CLI parser. + + Parameters + ---------- + parser + Lightning CLI parser instance to configure. + """ parser.set_defaults( { "trainer.logger": lazy_instance( @@ -45,8 +61,8 @@ def _setup_environment() -> None: def main() -> None: - """ - Main Lightning CLI entry point. + """Run the Lightning CLI entry point. + Parse log level and set TF32 precision. Set default random seed to 42. """ diff --git a/viscy/data/cell_classification.py b/viscy/data/cell_classification.py index ac72a6601..292e0507f 100644 --- a/viscy/data/cell_classification.py +++ b/viscy/data/cell_classification.py @@ -1,5 +1,7 @@ +"""Dataset and DataModule classes for cell classification tasks.""" + +from collections.abc import Callable from pathlib import Path -from typing import Callable import pandas as pd import torch @@ -14,6 +16,29 @@ class ClassificationDataset(Dataset): + """Dataset for cell classification tasks. + + A PyTorch Dataset that provides cell patches and labels for classification. + Loads image patches from HCS OME-Zarr data based on cell annotations. + + Parameters + ---------- + plate : Plate + HCS OME-Zarr plate containing image data. + annotation : pd.DataFrame + DataFrame with cell annotations and labels. + channel_name : str + Name of the image channel to load. + z_range : tuple[int, int] + Range of Z slices to include (start, end). + transform : Callable | None, optional + Transform to apply to image patches. + initial_yx_patch_size : tuple[int, int] + Initial patch size in Y and X dimensions. + return_indices : bool + Whether to return cell indices with patches, by default False. + """ + def __init__( self, plate: Plate, @@ -44,11 +69,25 @@ def __init__( ] def __len__(self): + """Return the number of samples in the dataset.""" return len(self.annotation) def __getitem__( self, idx ) -> tuple[Tensor, Tensor] | tuple[Tensor, Tensor, dict[str, int | str]]: + """ + Get a sample from the dataset. + + Parameters + ---------- + idx : int + Index of the sample to retrieve. + + Returns + ------- + tuple[Tensor, Tensor] or tuple[Tensor, Tensor, dict[str, int | str]] + Image tensor, label tensor, and optionally cell indices. + """ row = self.annotation.iloc[idx] fov_name, t, y, x = row["fov_name"], row["t"], row["y"], row["x"] fov = self.plate[fov_name] @@ -74,6 +113,37 @@ def __getitem__( class ClassificationDataModule(LightningDataModule): + """Lightning DataModule for cell classification tasks. + + Manages data loading and preprocessing for cell classification workflows. + Handles train/validation splits and applies appropriate transforms. + + Parameters + ---------- + image_path : Path + Path to HCS OME-Zarr image data. + annotation_path : Path + Path to cell annotation CSV file. + val_fovs : list[str], optional + List of FOV names to use for validation. + channel_name : str + Name of the image channel to load. + z_range : tuple[int, int] + Range of Z slices to include (start, end). + train_exlude_timepoints : list[int] + Timepoints to exclude from training data. + train_transforms : list[Callable], optional + List of transforms to apply to training data. + val_transforms : list[Callable], optional + List of transforms to apply to validation data. 
+ initial_yx_patch_size : tuple[int, int] + Initial patch size in Y and X dimensions. + batch_size : int + Batch size for data loading. + num_workers : int + Number of workers for data loading. + """ + def __init__( self, image_path: Path, @@ -123,7 +193,22 @@ def _subset( return_indices=return_indices, ) - def setup(self, stage=None): + def setup(self, stage=None) -> None: + """ + Set up datasets for the specified stage. + + Parameters + ---------- + stage : str, optional + Stage to set up for ('fit', 'validate', 'predict', 'test'). + + Raises + ------ + NotImplementedError + If stage is 'test'. + ValueError + If stage is unknown. + """ plate = open_ome_zarr(self.image_path) all_fovs = ["/" + name for (name, _) in plate.positions()] annotation = pd.read_csv(self.annotation_path) @@ -158,9 +243,17 @@ def setup(self, stage=None): elif stage == "test": raise NotImplementedError("Test stage not implemented.") else: - raise (f"Unknown stage: {stage}") + raise ValueError(f"Unknown stage: {stage}") - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: + """ + Create training data loader. + + Returns + ------- + DataLoader + Training data loader with shuffling enabled. + """ return DataLoader( self.train_dataset, batch_size=self.batch_size, @@ -168,7 +261,15 @@ def train_dataloader(self): shuffle=True, ) - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: + """ + Create validation data loader. + + Returns + ------- + DataLoader + Validation data loader without shuffling. + """ return DataLoader( self.val_dataset, batch_size=self.batch_size, @@ -176,7 +277,15 @@ def val_dataloader(self): shuffle=False, ) - def predict_dataloader(self): + def predict_dataloader(self) -> DataLoader: + """ + Create prediction data loader. + + Returns + ------- + DataLoader + Prediction data loader without shuffling. + """ return DataLoader( self.predict_dataset, batch_size=self.batch_size, diff --git a/viscy/data/combined.py b/viscy/data/combined.py index d4f2c8330..90e111729 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -1,8 +1,16 @@ +"""Combined data modules for multi-dataset ML training workflows. + +This module provides Lightning DataModule implementations for combining multiple +data sources with various strategies including concatenation, batching, and +distributed sampling optimizations for computer vision and microscopy datasets. +""" + import bisect import logging from collections import defaultdict +from collections.abc import Sequence from enum import Enum -from typing import Literal, Sequence +from typing import Literal import torch from lightning.pytorch import LightningDataModule @@ -17,6 +25,12 @@ class CombineMode(Enum): + """Enumeration of data combination modes for CombinedDataModule. + + Defines how multiple data modules should be combined during training, + validation, and testing phases. + """ + MIN_SIZE = "min_size" MAX_SIZE_CYCLE = "max_size_cycle" MAX_SIZE = "max_size" @@ -25,6 +39,7 @@ class CombineMode(Enum): class CombinedDataModule(LightningDataModule): """Wrapper for combining multiple data modules. + For supported modes, see ``lightning.pytorch.utilities.combined_loader``. Parameters @@ -57,31 +72,71 @@ def __init__( self.predict_mode = CombineMode(predict_mode).value self.prepare_data_per_node = True - def prepare_data(self): + def prepare_data(self) -> None: + """Prepare data for all constituent data modules. 
+ + Propagates trainer reference and calls prepare_data on each + data module for dataset downloading and preprocessing. + """ for dm in self.data_modules: dm.trainer = self.trainer dm.prepare_data() - def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: + """Set up data modules for specified training stage. + + Parameters + ---------- + stage : Literal["fit", "validate", "test", "predict"] + Current training stage for Lightning setup. + """ for dm in self.data_modules: dm.setup(stage) - def train_dataloader(self): + def train_dataloader(self) -> CombinedLoader: + """Create combined training dataloader. + + Returns + ------- + CombinedLoader + Combined dataloader using specified train_mode strategy. + """ return CombinedLoader( [dm.train_dataloader() for dm in self.data_modules], mode=self.train_mode ) - def val_dataloader(self): + def val_dataloader(self) -> CombinedLoader: + """Create combined validation dataloader. + + Returns + ------- + CombinedLoader + Combined dataloader using specified val_mode strategy. + """ return CombinedLoader( [dm.val_dataloader() for dm in self.data_modules], mode=self.val_mode ) - def test_dataloader(self): + def test_dataloader(self) -> CombinedLoader: + """Create combined test dataloader. + + Returns + ------- + CombinedLoader + Combined dataloader using specified test_mode strategy. + """ return CombinedLoader( [dm.test_dataloader() for dm in self.data_modules], mode=self.test_mode ) - def predict_dataloader(self): + def predict_dataloader(self) -> CombinedLoader: + """Create combined prediction dataloader. + + Returns + ------- + CombinedLoader + Combined dataloader using specified predict_mode strategy. + """ return CombinedLoader( [dm.predict_dataloader() for dm in self.data_modules], mode=self.predict_mode, @@ -89,10 +144,45 @@ def predict_dataloader(self): class BatchedConcatDataset(ConcatDataset): - def __getitem__(self, idx): + """Batched concatenated dataset for efficient multi-dataset sampling. + + Extends PyTorch's ConcatDataset to support batched item retrieval + from multiple datasets with optimized index grouping for ML training. + """ + + def __getitem__(self, idx: int): + """Retrieve single item by index. + + Parameters + ---------- + idx : int + Sample index across concatenated datasets. + + Raises + ------ + NotImplementedError + Single item access not implemented; use __getitems__ instead. + """ raise NotImplementedError def _get_sample_indices(self, idx: int) -> tuple[int, int]: + """Map global index to dataset and sample indices. + + Parameters + ---------- + idx : int + Global index across all concatenated datasets. + + Returns + ------- + tuple[int, int] + Dataset index and local sample index within that dataset. + + Raises + ------ + ValueError + If absolute index value exceeds dataset length. + """ if idx < 0: if -idx > len(self): raise ValueError( @@ -107,6 +197,21 @@ def _get_sample_indices(self, idx: int) -> tuple[int, int]: return dataset_idx, sample_idx def __getitems__(self, indices: list[int]) -> list: + """Retrieve multiple items by indices with batched dataset access. + + Groups indices by source dataset and performs batched retrieval + for improved data loading performance during ML training. + + Parameters + ---------- + indices : list[int] + List of global indices across concatenated datasets. + + Returns + ------- + list + Samples from all requested indices, maintaining order. 
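+
+        Notes
+        -----
+        A hedged sketch of the access pattern; ``ds_a`` and ``ds_b`` are
+        hypothetical datasets implementing ``__getitems__``:
+
+        .. code-block:: python
+
+            dataset = BatchedConcatDataset([ds_a, ds_b])
+            # indices are grouped by source dataset, so each underlying
+            # dataset is queried once with its batched local indices
+            samples = dataset.__getitems__([0, 1, len(ds_a)])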
+ """ grouped_indices = defaultdict(list) for idx in indices: dataset_idx, sample_indices = self._get_sample_indices(idx) @@ -151,11 +256,33 @@ def __init__(self, data_modules: Sequence[LightningDataModule]): self.prepare_data_per_node = True def prepare_data(self): + """Prepare data for all constituent data modules. + + Propagates trainer reference and calls prepare_data on each + data module for dataset preparation and preprocessing. + """ for dm in self.data_modules: dm.trainer = self.trainer dm.prepare_data() def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + """Set up concatenated datasets for training stage. + + Validates patch configuration consistency across data modules + and creates concatenated train/validation datasets. + + Parameters + ---------- + stage : Literal["fit", "validate", "test", "predict"] + Training stage - only "fit" currently supported. + + Raises + ------ + ValueError + If patches per stack are inconsistent across data modules. + NotImplementedError + If stage other than "fit" is requested. + """ self.train_patches_per_stack = 0 for dm in self.data_modules: dm.setup(stage) @@ -174,6 +301,14 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]): ) def _dataloader_kwargs(self) -> dict: + """Get common dataloader configuration parameters. + + Returns + ------- + dict + Common PyTorch DataLoader configuration parameters including + worker settings, memory pinning, and prefetch configuration. + """ return { "num_workers": self.num_workers, "persistent_workers": self.persistent_workers, @@ -181,7 +316,15 @@ def _dataloader_kwargs(self) -> dict: "pin_memory": self.pin_memory, } - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: + """Create training dataloader for concatenated datasets. + + Returns + ------- + DataLoader + PyTorch DataLoader with shuffling enabled, batch size adjusted + for patch stacking, and sample collation for ML training. + """ return DataLoader( self.train_dataset, shuffle=True, @@ -191,7 +334,15 @@ def train_dataloader(self): **self._dataloader_kwargs(), ) - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: + """Create validation dataloader for concatenated datasets. + + Returns + ------- + DataLoader + PyTorch DataLoader without shuffling for deterministic + validation evaluation. + """ return DataLoader( self.val_dataset, shuffle=False, @@ -202,9 +353,23 @@ def val_dataloader(self): class BatchedConcatDataModule(ConcatDataModule): + """Concatenated data module with batched dataset access. + + Extends ConcatDataModule to use BatchedConcatDataset and + ThreadDataLoader for optimized multi-dataset training performance. + """ + _ConcatDataset = BatchedConcatDataset - def train_dataloader(self): + def train_dataloader(self) -> ThreadDataLoader: + """Create threaded training dataloader for batched access. + + Returns + ------- + ThreadDataLoader + MONAI ThreadDataLoader with thread-based workers for + optimized batched dataset access during training. + """ return ThreadDataLoader( self.train_dataset, use_thread_workers=True, @@ -214,7 +379,15 @@ def train_dataloader(self): **self._dataloader_kwargs(), ) - def val_dataloader(self): + def val_dataloader(self) -> ThreadDataLoader: + """Create threaded validation dataloader for batched access. + + Returns + ------- + ThreadDataLoader + MONAI ThreadDataLoader with thread-based workers for + optimized validation data loading. 
+ """ return ThreadDataLoader( self.val_dataset, use_thread_workers=True, @@ -226,6 +399,13 @@ def val_dataloader(self): class CachedConcatDataModule(LightningDataModule): + """Cached concatenated data module for distributed training. + + Concatenates multiple data modules with support for distributed + sampling and caching optimizations for large-scale ML training. + # TODO: MANUAL_REVIEW - Verify caching behavior and memory usage + """ + def __init__(self, data_modules: Sequence[LightningDataModule]): super().__init__() self.data_modules = data_modules @@ -239,11 +419,33 @@ def __init__(self, data_modules: Sequence[LightningDataModule]): self.prepare_data_per_node = True def prepare_data(self): + """Prepare data for all constituent data modules. + + Propagates trainer reference and calls prepare_data on each + data module for dataset preparation and caching setup. + """ for dm in self.data_modules: dm.trainer = self.trainer dm.prepare_data() def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + """Set up cached concatenated datasets for distributed training. + + Validates patch configuration and creates concatenated datasets + with caching optimizations for efficient distributed access. + + Parameters + ---------- + stage : Literal["fit", "validate", "test", "predict"] + Training stage - only "fit" currently supported. + + Raises + ------ + ValueError + If patches per stack are inconsistent across data modules. + NotImplementedError + If stage other than "fit" is requested. + """ self.train_patches_per_stack = 0 for dm in self.data_modules: dm.setup(stage) @@ -262,6 +464,21 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]): def _maybe_sampler( self, dataset: Dataset, shuffle: bool ) -> ShardedDistributedSampler | None: + """Create distributed sampler if in distributed training mode. + + Parameters + ---------- + dataset : Dataset + PyTorch dataset to create sampler for. + shuffle : bool + Whether to shuffle samples across distributed processes. + + Returns + ------- + ShardedDistributedSampler | None + Distributed sampler if PyTorch distributed is initialized, + None otherwise for single-process training. + """ return ( ShardedDistributedSampler(dataset, shuffle=shuffle) if torch.distributed.is_initialized() @@ -269,6 +486,14 @@ def _maybe_sampler( ) def train_dataloader(self) -> DataLoader: + """Create training dataloader with distributed sampling support. + + Returns + ------- + DataLoader + PyTorch DataLoader with distributed sampler if available, + configured for cached dataset access during training. + """ sampler = self._maybe_sampler(self.train_dataset, shuffle=True) return DataLoader( self.train_dataset, @@ -282,6 +507,14 @@ def train_dataloader(self) -> DataLoader: ) def val_dataloader(self) -> DataLoader: + """Create validation dataloader with distributed sampling support. + + Returns + ------- + DataLoader + PyTorch DataLoader with distributed sampler if available, + configured for deterministic validation evaluation. + """ sampler = self._maybe_sampler(self.val_dataset, shuffle=False) return DataLoader( self.val_dataset, diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index 3c888175c..811a1c6ba 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -1,3 +1,5 @@ +"""Data module for CTMCv1 autoregression dataset with HCS OME-Zarr stores.""" + from pathlib import Path import torch @@ -10,6 +12,7 @@ class CTMCv1DataModule(GPUTransformDataModule): """ Autoregression data module for the CTMCv1 dataset. 
+ Training and validation datasets are stored in separate HCS OME-Zarr stores. Parameters @@ -18,13 +21,13 @@ class CTMCv1DataModule(GPUTransformDataModule): Path to the training dataset. val_data_path : str or Path Path to the validation dataset. - train_cpu_transforms : list of MapTransform + train_cpu_transforms : list[MapTransform] List of CPU transforms for training. - val_cpu_transforms : list of MapTransform + val_cpu_transforms : list[MapTransform] List of CPU transforms for validation. - train_gpu_transforms : list of MapTransform + train_gpu_transforms : list[MapTransform] List of GPU transforms for training. - val_gpu_transforms : list of MapTransform + val_gpu_transforms : list[MapTransform] List of GPU transforms for validation. batch_size : int, optional Batch size, by default 16. @@ -68,21 +71,38 @@ def __init__( @property def train_cpu_transforms(self) -> Compose: + """Get composed training CPU transforms.""" return self._train_cpu_transforms @property def val_cpu_transforms(self) -> Compose: + """Get composed validation CPU transforms.""" return self._val_cpu_transforms @property def train_gpu_transforms(self) -> Compose: + """Get composed training GPU transforms.""" return self._train_gpu_transforms @property def val_gpu_transforms(self) -> Compose: + """Get composed validation GPU transforms.""" return self._val_gpu_transforms def setup(self, stage: str) -> None: + """ + Set up datasets for the specified stage. + + Parameters + ---------- + stage : str + The stage to set up for. Only "fit" is currently supported. + + Raises + ------ + NotImplementedError + If stage is not "fit". + """ if stage != "fit": raise NotImplementedError("Only fit stage is supported") self._setup_fit() diff --git a/viscy/data/distributed.py b/viscy/data/distributed.py index 68e6d39e5..badce1ab3 100644 --- a/viscy/data/distributed.py +++ b/viscy/data/distributed.py @@ -14,9 +14,17 @@ class ShardedDistributedSampler(DistributedSampler): + """Distributed sampler that creates sharded random permutations. + + A specialized DistributedSampler that generates sharded random permutations + to ensure proper data distribution across multiple processes in DDP training. + """ + def _sharded_randperm(self, max_size: int, generator: Generator) -> list[int]: """Generate a sharded random permutation of indices. - Overlap may occur in between the last two shards to maintain divisibility.""" + + Overlap may occur in between the last two shards to maintain divisibility. + """ sharded_randperm = [ torch.randperm(self.num_samples, generator=generator) + min(i * self.num_samples, max_size - self.num_samples) diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py index 5eb200f33..abca552e5 100644 --- a/viscy/data/gpu_aug.py +++ b/viscy/data/gpu_aug.py @@ -49,6 +49,16 @@ def _maybe_sampler( ) def train_dataloader(self) -> DataLoader: + """Create GPU-optimized training data loader. + + Configures distributed sampling, persistent workers, and memory pinning + for efficient GPU-accelerated batch processing during training. + + Returns + ------- + DataLoader + Training data loader with GPU optimization settings. + """ sampler = self._maybe_sampler(self.train_dataset, shuffle=True) _logger.debug(f"Using training sampler {sampler}") return DataLoader( @@ -65,6 +75,16 @@ def train_dataloader(self) -> DataLoader: ) def val_dataloader(self) -> DataLoader: + """Create GPU-optimized validation data loader. 
+ + Configures distributed sampling and memory pinning for efficient + GPU-accelerated batch processing during validation phase. + + Returns + ------- + DataLoader + Validation data loader with GPU optimization settings. + """ sampler = self._maybe_sampler(self.val_dataset, shuffle=False) _logger.debug(f"Using validation sampler {sampler}") return DataLoader( @@ -82,19 +102,63 @@ def val_dataloader(self) -> DataLoader: @property @abstractmethod - def train_cpu_transforms(self) -> Compose: ... + def train_cpu_transforms(self) -> Compose: + """CPU-based transform pipeline for training data. + + Returns pre-GPU augmentation transforms executed on CPU before + GPU transfer to optimize memory bandwidth and device utilization. + + Returns + ------- + Compose + Composed CPU transforms for training preprocessing. + """ + ... @property @abstractmethod - def train_gpu_transforms(self) -> Compose: ... + def train_gpu_transforms(self) -> Compose: + """GPU-accelerated transform pipeline for training data. + + Returns GPU-resident transforms for high-performance augmentation + with device memory optimization during training workflows. + + Returns + ------- + Compose + Composed GPU transforms for training augmentation. + """ + ... @property @abstractmethod - def val_cpu_transforms(self) -> Compose: ... + def val_cpu_transforms(self) -> Compose: + """CPU-based transform pipeline for validation data. + + Returns pre-GPU validation transforms executed on CPU for + deterministic preprocessing before GPU transfer. + + Returns + ------- + Compose + Composed CPU transforms for validation preprocessing. + """ + ... @property @abstractmethod - def val_gpu_transforms(self) -> Compose: ... + def val_gpu_transforms(self) -> Compose: + """GPU-accelerated transform pipeline for validation data. + + Returns GPU-resident transforms for consistent device-optimized + preprocessing during validation phase. + + Returns + ------- + Compose + Composed GPU transforms for validation processing. + """ + ... class CachedOmeZarrDataset(Dataset): @@ -147,7 +211,7 @@ def __init__( def __len__(self) -> int: return len(self._metadata_map) - def __getitem__(self, idx: int) -> dict[str, Tensor]: + def __getitem__(self, idx: int) -> dict[str, Tensor] | list[dict[str, Tensor]]: position, time_idx, norm_meta = self._metadata_map[idx] cache = self._cache_map[idx] if cache is None: @@ -240,18 +304,46 @@ def __init__( @property def train_cpu_transforms(self) -> Compose: + """CPU-based transform pipeline for training data. + + Returns + ------- + Compose + Composed CPU transforms applied before GPU transfer. + """ return self._train_cpu_transforms @property def train_gpu_transforms(self) -> Compose: + """GPU-accelerated transform pipeline for training data. + + Returns + ------- + Compose + Composed GPU transforms for device-optimized augmentation. + """ return self._train_gpu_transforms @property def val_cpu_transforms(self) -> Compose: + """CPU-based transform pipeline for validation data. + + Returns + ------- + Compose + Composed CPU transforms applied before GPU transfer. + """ return self._val_cpu_transforms @property def val_gpu_transforms(self) -> Compose: + """GPU-accelerated transform pipeline for validation data. + + Returns + ------- + Compose + Composed GPU transforms for device-optimized processing. 
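+
+        A hedged sketch of the intended call site; ``batch`` is a
+        hypothetical sample dictionary already transferred to the GPU:
+
+        .. code-block:: python
+
+            batch = dm.val_gpu_transforms(batch)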
+ """ return self._val_gpu_transforms def _set_fit_global_state(self, num_positions: int) -> list[int]: @@ -279,6 +371,24 @@ def _filter_fit_fovs(self, plate: Plate) -> list[Position]: return positions def setup(self, stage: Literal["fit", "validate"]) -> None: + """Set up datasets with GPU-optimized caching and memory management. + + Configures train/validation split with shared memory caching for + efficient GPU batch loading. Initializes MONAI metadata tracking + and distributed data sampling. + + Parameters + ---------- + stage : Literal["fit", "validate"] + PyTorch Lightning stage for dataset configuration. + + Raises + ------ + NotImplementedError + If stage is not "fit" or "validate". + ValueError + If fewer than 2 FOVs available for train/validation split. + """ if stage not in ("fit", "validate"): raise NotImplementedError("Only fit and validate stages are supported.") cache_map = Manager().dict() diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 4a7e9dfb5..c5501c385 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -1,10 +1,13 @@ +"""High-Content Screening (HCS) data loading and preprocessing module.""" + import logging import math import os import re import tempfile +from collections.abc import Callable, Sequence from pathlib import Path -from typing import Callable, Literal, Sequence +from typing import Literal import numpy as np import torch @@ -49,7 +52,26 @@ def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]: def _search_int_in_str(pattern: str, file_name: str) -> str: """Search image indices in a file name with regex patterns and strip leading zeros. - E.g. ``'001'`` -> ``1``""" + + E.g. ``'001'`` -> ``1``. + + Parameters + ---------- + pattern : str + Regex pattern to search for in filename + file_name : str + Filename to search within + + Returns + ------- + str + Extracted string with leading zeros stripped + + Raises + ------ + ValueError + If pattern is not found in filename + """ match = re.search(pattern, file_name) if match: return match.group() @@ -78,9 +100,19 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: def _read_norm_meta(fov: Position) -> NormMeta | None: - """ - Read normalization metadata from the FOV. + """Read normalization metadata from the FOV. + Convert to float32 tensors to avoid automatic casting to float64. + + Parameters + ---------- + fov : Position + OME-Zarr Position object containing metadata + + Returns + ------- + NormMeta | None + Normalization metadata dictionary or None if not available """ norm_meta = fov.zattrs.get("normalization", None) if norm_meta is None: @@ -97,15 +129,21 @@ def _read_norm_meta(fov: Position) -> NormMeta | None: class SlidingWindowDataset(Dataset): - """Torch dataset where each element is a window of - (C, Z, Y, X) where C=2 (source and target) and Z is ``z_window_size``. + """Torch dataset where each element is a window of (C, Z, Y, X). - :param list[Position] positions: FOVs to include in dataset - :param ChannelMap channels: source and target channel names, + Where C=2 (source and target) and Z is ``z_window_size``. + + Parameters + ---------- + positions : list[Position] + FOVs to include in dataset + channels : ChannelMap + Source and target channel names, e.g. 
``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` - :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param DictTransform | None transform: - a callable that transforms data, defaults to None + z_window_size : int + Z window size of the 2.5D U-Net, 1 for 2D + transform : DictTransform | None, optional + A callable that transforms data, by default None """ def __init__( @@ -131,8 +169,10 @@ def __init__( self._get_windows() def _get_windows(self) -> None: - """Count the sliding windows along T and Z, - and build an index-to-window LUT.""" + """Count the sliding windows along T and Z. + + And build an index-to-window LUT. + """ w = 0 self.window_keys = [] self.window_arrays = [] @@ -163,16 +203,28 @@ def _find_window(self, index: int) -> tuple[ImageArray, int, NormMeta | None]: def _read_img_window( self, img: ImageArray, ch_idx: list[int], tz: int - ) -> tuple[list[Tensor], HCSStackIndex]: + ) -> tuple[tuple[Tensor, ...], HCSStackIndex]: """Read image window as tensor. - :param ImageArray img: NGFF image array - :param list[int] ch_idx: list of channel indices to read, - output channel ordering will reflect the sequence - :param int tz: window index within the FOV, counted Z-first - :return list[Tensor], HCSStackIndex: + Parameters + ---------- + img : ImageArray + NGFF image array + ch_idx : list[int] + list of channel indices to read, output channel ordering will reflect the sequence + tz : int + window index within the FOV, counted Z-first + + Returns + ------- + tuple[tuple[Tensor], HCSStackIndex] list of (C=1, Z, Y, X) image tensors, tuple of image name, time index, and Z index + + Raises + ------ + IndexError + If the window index is out of bounds """ zs = img.shape[-3] - self.z_window_size + 1 t = (tz + zs) // zs - 1 @@ -182,7 +234,7 @@ def _read_img_window( [int(i) for i in ch_idx], slice(z, z + self.z_window_size), ].astype(np.float32) - return torch.from_numpy(data).unbind(dim=1), (img.name, t, z) + return torch.from_numpy(data).unbind(dim=1), HCSStackIndex(img.name, t, z) def __len__(self) -> int: return self._max_window @@ -233,19 +285,27 @@ def __getitem__(self, index: int) -> Sample: class MaskTestDataset(SlidingWindowDataset): - """Torch dataset where each element is a window of - (C, Z, Y, X) where C=2 (source and target) and Z is ``z_window_size``. - This a testing stage version of :py:class:`viscy.data.hcs.SlidingWindowDataset`, - and can only be used with batch size 1 for efficiency (no padding for collation), - since the mask is not available for each stack. - - :param list[Position] positions: FOVs to include in dataset - :param ChannelMap channels: source and target channel names, + """Torch dataset with ground truth masks for testing. + + Each element is a window of (C, Z, Y, X) where C=2 (source and target) + and Z is ``z_window_size``. This is a testing stage version of + :py:class:`viscy.data.hcs.SlidingWindowDataset`, and can only be used + with batch size 1 for efficiency (no padding for collation), since the + mask is not available for each stack. + + Parameters + ---------- + positions : list[Position] + FOVs to include in dataset + channels : ChannelMap + Source and target channel names, e.g. 
``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` - :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param DictTransform transform: - a callable that transforms data, defaults to None - :param str | None ground_truth_masks: path to the ground truth masks + z_window_size : int + Z window size of the 2.5D U-Net, 1 for 2D + transform : DictTransform | None, optional + A callable that transforms data, by default None + ground_truth_masks : str | None, optional + Path to the ground truth masks, by default None """ def __init__( @@ -367,7 +427,14 @@ def __init__( self.pin_memory = pin_memory @property - def cache_path(self): + def cache_path(self) -> Path: + """Get the temporary cache path for HCS data. + + Returns + ------- + Path + Cache directory path in system temp with SLURM job ID if available + """ return Path( tempfile.gettempdir(), os.getenv("SLURM_JOB_ID", "viscy_cache"), @@ -375,7 +442,14 @@ def cache_path(self): ) @property - def maybe_cached_data_path(self): + def maybe_cached_data_path(self) -> Path: + """Get data path, using cache if caching is enabled. + + Returns + ------- + Path + Cache path if caching enabled, otherwise original data path + """ return self.cache_path if self.caching else self.data_path def _data_log_path(self) -> Path: @@ -387,7 +461,12 @@ def _data_log_path(self) -> Path: log_dir.mkdir(parents=True, exist_ok=True) return log_dir / "data.log" - def prepare_data(self): + def prepare_data(self) -> None: + """Prepare HCS data by caching if enabled. + + Copies OME-Zarr data to temporary cache directory for improved + I/O performance during training. + """ if not self.caching: return # setup logger @@ -424,6 +503,18 @@ def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]: } def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + """Set up datasets for the specified Lightning stage. + + Parameters + ---------- + stage : Literal["fit", "validate", "test", "predict"] + Current training stage for Lightning setup + + Raises + ------ + NotImplementedError + If stage is not supported + """ dataset_settings = self._base_dataset_settings if stage in ("fit", "validate"): self._setup_fit(dataset_settings) @@ -546,7 +637,15 @@ def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample batch["target"] = batch["target"][:, :, slice(z_index, z_index + 1)] return batch - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: + """Create training DataLoader for HCS data. + + Returns + ------- + DataLoader + Training DataLoader with shuffling, batch collation, and + multi-worker support for HCS sliding window sampling + """ return DataLoader( self.train_dataset, batch_size=self.batch_size // self.train_patches_per_stack, @@ -559,7 +658,15 @@ def train_dataloader(self): pin_memory=self.pin_memory, ) - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: + """Create validation DataLoader for HCS data. + + Returns + ------- + DataLoader + Validation DataLoader without shuffling for deterministic + validation evaluation on HCS datasets + """ return DataLoader( self.val_dataset, batch_size=self.batch_size, @@ -570,7 +677,15 @@ def val_dataloader(self): pin_memory=self.pin_memory, ) - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: + """Create test DataLoader for HCS data with optional ground truth masks. 
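+
+        Uses a fixed batch size of 1 so that ground truth masks, which are
+        not padded for collation, can be loaded alongside each stack.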
+ + Returns + ------- + DataLoader + Test DataLoader with batch size 1 for mask compatibility + and optional ground truth mask loading for segmentation metrics + """ return DataLoader( self.test_dataset, batch_size=1, @@ -578,7 +693,15 @@ def test_dataloader(self): shuffle=False, ) - def predict_dataloader(self): + def predict_dataloader(self) -> DataLoader: + """Create prediction DataLoader for HCS data. + + Returns + ------- + DataLoader + Prediction DataLoader for inference on HCS datasets + with metadata tracking enabled for transform inversion + """ return DataLoader( self.predict_dataset, batch_size=self.batch_size, @@ -587,8 +710,16 @@ def predict_dataloader(self): ) def _fit_transform(self) -> tuple[Compose, Compose]: - """(normalization -> maybe augmentation -> center crop) - Deterministic center crop as the last step of training and validation.""" + """Create training and validation transform pipelines. + + (normalization -> maybe augmentation -> center crop) + Deterministic center crop as the last step of training and validation. + + Returns + ------- + tuple[Compose, Compose] + Training and validation transform compositions + """ # TODO: These have a fixed order for now... () final_crop = [ CenterSpatialCropd( @@ -607,8 +738,16 @@ def _fit_transform(self) -> tuple[Compose, Compose]: return train_transform, val_transform def _train_transform(self) -> list[Callable]: - """Setup training augmentations: check input values, - and parse the number of Z slices and patches to sample per stack.""" + """Setup training augmentations. + + Check input values and parse the number of Z slices and patches to + sample per stack. + + Returns + ------- + list[Callable] + List of training augmentation transforms + """ self.train_patches_per_stack = 1 z_scale_range = None if self.augmentations: diff --git a/viscy/data/livecell.py b/viscy/data/livecell.py index e8da1eb45..d7134fe87 100644 --- a/viscy/data/livecell.py +++ b/viscy/data/livecell.py @@ -124,6 +124,42 @@ def __getitem__(self, idx: int) -> Sample: class LiveCellDataModule(GPUTransformDataModule): + """Data module for LiveCell microscopy dataset. + + Provides train, validation, and test dataloaders for the LiveCell + dataset containing single-cell segmentation annotations for multiple + cell types in live-cell imaging. + + Parameters + ---------- + train_val_images : Path | None, optional + Path to the training and validation images. + test_images : Path | None, optional + Path to the test images. + train_annotations : Path | None, optional + Path to the training annotations. + val_annotations : Path | None, optional + Path to the validation annotations. + test_annotations : Path | None, optional + Path to the test annotations. + train_cpu_transforms : list[MapTransform], optional + List of CPU transforms for training. + val_cpu_transforms : list[MapTransform], optional + List of CPU transforms for validation. + train_gpu_transforms : list[MapTransform], optional + List of GPU transforms for training. + val_gpu_transforms : list[MapTransform], optional + List of GPU transforms for validation. + test_transforms : list[MapTransform], optional + List of transforms for testing. + batch_size : int, optional + Batch size, by default 16. + num_workers : int, optional + Number of dataloading workers, by default 8. + pin_memory : bool, optional + Pin memory for dataloaders, by default True. 
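+
+    Examples
+    --------
+    A hedged construction sketch; the paths are placeholders for a local
+    copy of the LiveCell images and annotation files:
+
+    .. code-block:: python
+
+        from pathlib import Path
+
+        from viscy.data.livecell import LiveCellDataModule
+
+        dm = LiveCellDataModule(
+            train_val_images=Path("livecell/images"),
+            train_annotations=Path("livecell/train.json"),
+            val_annotations=Path("livecell/val.json"),
+            batch_size=16,
+            num_workers=8,
+        )
+        dm.setup("fit")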
+ """ + def __init__( self, train_val_images: Path | None = None, @@ -172,21 +208,56 @@ def __init__( @property def train_cpu_transforms(self) -> Compose: + """Get CPU transforms for training data augmentation. + + Returns + ------- + Compose + Composed transforms applied on CPU during training. + """ return self._train_cpu_transforms @property def val_cpu_transforms(self) -> Compose: + """Get CPU transforms for validation data processing. + + Returns + ------- + Compose + Composed transforms applied on CPU during validation. + """ return self._val_cpu_transforms @property def train_gpu_transforms(self) -> Compose: + """Get GPU transforms for training data augmentation. + + Returns + ------- + Compose + Composed transforms applied on GPU during training. + """ return self._train_gpu_transforms @property def val_gpu_transforms(self) -> Compose: + """Get GPU transforms for validation data processing. + + Returns + ------- + Compose + Composed transforms applied on GPU during validation. + """ return self._val_gpu_transforms def setup(self, stage: str) -> None: + """Set up datasets based on the specified stage. + + Parameters + ---------- + stage : str + Either "fit" for training/validation or "test" for testing. + """ if stage == "fit": self._setup_fit() elif stage == "test": @@ -221,6 +292,13 @@ def _setup_test(self) -> None: ) def test_dataloader(self) -> DataLoader: + """Create test data loader. + + Returns + ------- + DataLoader + Test data loader with LiveCell test dataset. + """ return DataLoader( self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers ) diff --git a/viscy/data/mmap_cache.py b/viscy/data/mmap_cache.py index 735159903..b3cf427f4 100644 --- a/viscy/data/mmap_cache.py +++ b/viscy/data/mmap_cache.py @@ -30,6 +30,31 @@ class MmappedDataset(Dataset): + """Dataset for memory-mapped OME-Zarr arrays with caching. + + Provides efficient access to time-series microscopy data through + memory-mapped tensors with lazy loading and caching capabilities. + + Parameters + ---------- + positions : list[Position] + List of FOVs to load images from. + channel_names : list[str] + List of channel names to load. + cache_map : DictProxy + Shared dictionary for caching loaded volumes. + buffer : MemoryMappedTensor + Memory-mapped tensor for caching loaded volumes. + preprocess_transforms : Compose | None, optional + Composed transforms to be applied on the CPU, by default None + cpu_transform : Compose | None, optional + Composed transforms to be applied on the CPU, by default None + array_key : str, optional + The image array key name (multi-scale level), by default "0" + load_normalization_metadata : bool, optional + Load normalization metadata in the sample dictionary, by default True + """ + def __init__( self, positions: list[Position], @@ -178,26 +203,71 @@ def __init__( @property def preprocessing_transforms(self) -> Compose: + """Get preprocessing transforms for data normalization. + + Returns + ------- + Compose + Composed transforms for preprocessing image data. + """ return self._preprocessing_transforms @property def train_cpu_transforms(self) -> Compose: + """Get CPU transforms for training data augmentation. + + Returns + ------- + Compose + Composed transforms applied on CPU during training. + """ return self._train_cpu_transforms @property def train_gpu_transforms(self) -> Compose: + """Get GPU transforms for training data augmentation. + + Returns + ------- + Compose + Composed transforms applied on GPU during training. 
+ """ return self._train_gpu_transforms @property def val_cpu_transforms(self) -> Compose: + """Get CPU transforms for validation data processing. + + Returns + ------- + Compose + Composed transforms applied on CPU during validation. + """ return self._val_cpu_transforms @property def val_gpu_transforms(self) -> Compose: + """Get GPU transforms for validation data processing. + + Returns + ------- + Compose + Composed transforms applied on GPU during validation. + """ return self._val_gpu_transforms @property def cache_dir(self) -> Path: + """Get cache directory for memory-mapped files. + + Creates a unique cache directory based on SLURM job ID or + distributed rank for parallel training. + + Returns + ------- + Path + Cache directory path for storing memory-mapped tensor files. + """ scratch_dir = self.scratch_dir or Path(tempfile.gettempdir()) cache_dir = Path( scratch_dir, @@ -222,6 +292,22 @@ def _buffer_shape(self, arr_shape, fovs) -> tuple[int, ...]: return (len(fovs) * arr_shape[0], len(self.channels), *arr_shape[2:]) def setup(self, stage: Literal["fit", "validate"]) -> None: + """Set up datasets for training or validation. + + Creates memory-mapped datasets with train/val split based on the + specified stage. Initializes buffers and cache maps for efficient + data loading. + + Parameters + ---------- + stage : Literal["fit", "validate"] + Stage for which to set up the datasets. + + Raises + ------ + NotImplementedError + If stage is not "fit" or "validate". + """ if stage not in ("fit", "validate"): raise NotImplementedError("Only fit and validate stages are supported.") plate: Plate = open_ome_zarr(self.data_path, mode="r", layout="hcs") diff --git a/viscy/data/segmentation.py b/viscy/data/segmentation.py index 553d9241c..4a6f13716 100644 --- a/viscy/data/segmentation.py +++ b/viscy/data/segmentation.py @@ -15,6 +15,29 @@ class SegmentationDataset(Dataset): + """ + Dataset for segmentation evaluation tasks. + + Loads predicted and target segmentation masks for comparison and evaluation. + + Parameters + ---------- + pred_dataset : Plate + HCS OME-Zarr plate containing predicted segmentation masks. + target_dataset : Plate + HCS OME-Zarr plate containing ground truth segmentation masks. + pred_channel : str + Name of the prediction channel to load. + target_channel : str + Name of the target channel to load. + pred_z_slice : int or slice + Z slice selection for prediction data. + target_z_slice : int or slice + Z slice selection for target data. + img_name : str, optional + Name of the image array within positions, by default "0". + """ + def __init__( self, pred_dataset: Plate, @@ -50,9 +73,23 @@ def _build_indices(self) -> None: _logger.info(f"Number of test samples: {len(self)}") def __len__(self) -> int: + """Return the number of segmentation samples in the dataset.""" return len(self._indices) def __getitem__(self, idx: int) -> SegmentationSample: + """ + Get a segmentation sample pair. + + Parameters + ---------- + idx : int + Index of the sample to retrieve. + + Returns + ------- + SegmentationSample + Dictionary containing prediction, target, position index, and time index. + """ pred_img, target_img, p, t = self._indices[idx] _logger.debug(f"Target image: {target_img.name}") pred = torch.from_numpy( @@ -65,6 +102,32 @@ def __getitem__(self, idx: int) -> SegmentationSample: class SegmentationDataModule(LightningDataModule): + """ + Lightning DataModule for segmentation evaluation. + + Manages data loading for comparing predicted and target segmentation masks. 
+ Only supports test stage for evaluation purposes. + + Parameters + ---------- + pred_dataset : Path + Path to HCS OME-Zarr containing predicted segmentation masks. + target_dataset : Path + Path to HCS OME-Zarr containing ground truth segmentation masks. + pred_channel : str + Name of the prediction channel to load. + target_channel : str + Name of the target channel to load. + pred_z_slice : int + Z slice index for prediction data. + target_z_slice : int + Z slice index for target data. + batch_size : int + Batch size for data loading. + num_workers : int + Number of workers for data loading. + """ + def __init__( self, pred_dataset: Path, @@ -87,6 +150,19 @@ def __init__( self.num_workers = num_workers def setup(self, stage: str) -> None: + """ + Set up the segmentation dataset. + + Parameters + ---------- + stage : str + Stage to set up for. Only "test" is supported. + + Raises + ------ + NotImplementedError + If stage is not "test". + """ if stage != "test": raise NotImplementedError("Only test stage is supported!") self.test_dataset = SegmentationDataset( @@ -99,6 +175,14 @@ def setup(self, stage: str) -> None: ) def test_dataloader(self) -> DataLoader: + """ + Create test data loader for segmentation evaluation. + + Returns + ------- + DataLoader + Test data loader containing prediction-target pairs. + """ return DataLoader( self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers ) diff --git a/viscy/data/select.py b/viscy/data/select.py index 6e00c10e8..4a0e4539b 100644 --- a/viscy/data/select.py +++ b/viscy/data/select.py @@ -1,4 +1,4 @@ -from typing import Generator +from collections.abc import Generator from iohub.ngff.nodes import Plate, Position, Well @@ -21,6 +21,12 @@ def _filter_fovs( class SelectWell: + """Filter wells and fields-of-view for dataset selection. + + This class provides functionality to filter wells by inclusion criteria + and exclude specific fields-of-view from the dataset. + """ + _include_wells: list[str] | None _exclude_fovs: list[str] | None diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index 72828af42..1392f83f7 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -1,6 +1,7 @@ import logging +from collections.abc import Sequence from pathlib import Path -from typing import Literal, Sequence +from typing import Literal import numpy as np import pandas as pd @@ -66,6 +67,13 @@ def _transform_channel_wise( class TripletDataset(Dataset): + """Dataset for triplet sampling of tracked cells. + + Generates anchor, positive, and negative triplets from tracked cell + patches for contrastive learning. Supports temporal sampling with + configurable time intervals. + """ + def __init__( self, positions: list[Position], @@ -213,8 +221,21 @@ def _sample_positives(self, anchor_rows: pd.DataFrame) -> pd.DataFrame: return query.merge(self.tracks, on=["global_track_id", "t"], how="inner") def _sample_negative(self, anchor_row: pd.Series) -> pd.Series: - """Select a negative sample from a different track in the next time point - if an interval is specified, otherwise from any random time point.""" + """Select a negative sample from a different track. + + Selects from the next time point if an interval is specified, + otherwise from any random time point. + + Parameters + ---------- + anchor_row : pd.Series + Row containing anchor cell information. + + Returns + ------- + pd.Series + Row containing negative sample information. 
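The negative-sampling rule documented here can be pictured as a small pandas sketch (synthetic tracks table; the real implementation additionally restricts candidates to the next time point when an interval is set):

    import pandas as pd

    tracks = pd.DataFrame(
        {"global_track_id": [1, 1, 2, 2, 3, 3], "t": [0, 1, 0, 1, 0, 1]}
    )
    anchor = tracks.iloc[0]
    # Negatives come from a different track than the anchor.
    candidates = tracks[tracks["global_track_id"] != anchor["global_track_id"]]
    negative = candidates.sample(n=1).iloc[0]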
+ """ if self.time_interval == "any": tracks = self.tracks else: @@ -324,10 +345,17 @@ def __getitems__(self, indices: list[int]) -> list[TripletSample]: class TripletDataModule(HCSDataModule): + """Lightning data module for triplet sampling from tracked cells. + + Provides train, validation, and prediction dataloaders for contrastive + learning on cell tracking data. Supports configurable time intervals + and spatial patch sampling. + """ + def __init__( self, - data_path: str, - tracks_path: str, + data_path: str | Path, + tracks_path: str | Path, source_channel: str | Sequence[str], z_range: tuple[int, int], initial_yx_patch_size: tuple[int, int] = (512, 512), @@ -354,9 +382,9 @@ def __init__( Parameters ---------- - data_path : str + data_path : str | Path Image dataset path - tracks_path : str + tracks_path : str | Path Tracks labels dataset path source_channel : str | Sequence[str] List of input channel names @@ -436,13 +464,15 @@ def __init__( def _align_tracks_tables_with_positions( self, ) -> tuple[list[Position], list[pd.DataFrame]]: - """Parse positions in ome-zarr store containing tracking information - and assemble tracks tables for each position. + """Parse positions in ome-zarr store containing tracking information. + + Assembles tracks tables for each position by matching position names + with corresponding CSV files in the tracks directory. Returns ------- tuple[list[Position], list[pd.DataFrame]] - List of positions and list of tracks tables for each position + List of positions and list of tracks tables for each position. """ positions = [] tracks_tables = [] @@ -466,7 +496,7 @@ def _base_dataset_settings(self) -> dict: } def _update_to_device_transform(self): - "Make sure that GPU transforms are set to the current device." + """Make sure that GPU transforms are set to the current device.""" for transform in self.normalizations + self.augmentations: if isinstance(transform, ToDeviced): transform.converter.device = torch.device( @@ -535,7 +565,14 @@ def _setup_predict(self, dataset_settings: dict): def _setup_test(self, *args, **kwargs): raise NotImplementedError("Self-supervised model does not support testing") - def train_dataloader(self): + def train_dataloader(self) -> ThreadDataLoader: + """Create training data loader for triplet sampling. + + Returns + ------- + ThreadDataLoader + Training data loader with shuffling and thread workers. + """ return ThreadDataLoader( self.train_dataset, use_thread_workers=True, @@ -548,7 +585,14 @@ def train_dataloader(self): pin_memory=self.pin_memory, ) - def val_dataloader(self): + def val_dataloader(self) -> ThreadDataLoader: + """Create validation data loader for triplet sampling. + + Returns + ------- + ThreadDataLoader + Validation data loader without shuffling. + """ return ThreadDataLoader( self.val_dataset, use_thread_workers=True, @@ -561,7 +605,14 @@ def val_dataloader(self): pin_memory=self.pin_memory, ) - def predict_dataloader(self): + def predict_dataloader(self) -> ThreadDataLoader: + """Create prediction data loader for cell embedding extraction. + + Returns + ------- + ThreadDataLoader + Prediction data loader for anchor-only sampling. 
+ """ return ThreadDataLoader( self.predict_dataset, use_thread_workers=True, diff --git a/viscy/data/typing.py b/viscy/data/typing.py index d6a70488c..3d800953a 100644 --- a/viscy/data/typing.py +++ b/viscy/data/typing.py @@ -1,4 +1,7 @@ -from typing import Callable, NamedTuple, Sequence, TypedDict, TypeVar +"""Type definitions for VisCy data modules and structures.""" + +from collections.abc import Callable, Sequence +from typing import NamedTuple, TypedDict, TypeVar from torch import ShortTensor, Tensor @@ -13,6 +16,8 @@ class LevelNormStats(TypedDict): + """Statistics for normalization at a specific level (dataset or FOV).""" + mean: Tensor std: Tensor median: Tensor @@ -20,6 +25,8 @@ class LevelNormStats(TypedDict): class ChannelNormStats(TypedDict): + """Normalization statistics for a channel at different levels.""" + dataset_statistics: LevelNormStats fov_statistics: LevelNormStats @@ -39,6 +46,7 @@ class HCSStackIndex(NamedTuple): class Sample(TypedDict, total=False): """ Image sample type for mini-batches. + All fields are optional. """ @@ -54,9 +62,7 @@ class Sample(TypedDict, total=False): class SegmentationSample(TypedDict): - """ - Segmentation sample type for mini-batches. - """ + """Segmentation sample type for mini-batches.""" pred: ShortTensor target: ShortTensor @@ -72,17 +78,18 @@ class ChannelMap(TypedDict): class TrackingIndex(TypedDict): - """Tracking index extracted from ultrack result - Potentially collated by the dataloader""" + """ + Tracking index extracted from ultrack result. + + Potentially collated by the dataloader. + """ fov_name: OneOrSeq[str] id: OneOrSeq[int] class TripletSample(TypedDict): - """ - Triplet sample type for mini-batches. - """ + """Triplet sample type for mini-batches.""" anchor: Tensor positive: NotRequired[Tensor] diff --git a/viscy/preprocessing/generate_masks.py b/viscy/preprocessing/generate_masks.py index 491bc4069..8614ee13e 100644 --- a/viscy/preprocessing/generate_masks.py +++ b/viscy/preprocessing/generate_masks.py @@ -1,40 +1,47 @@ """Generate masks from sum of flurophore channels""" -import iohub.ngff as ngff +from pathlib import Path +from typing import Literal +import iohub.ngff as ngff import viscy.utils.aux_utils as aux_utils from viscy.utils.mp_utils import mp_create_and_write_mask class MaskProcessor: - """ - Appends Masks to zarr directories - """ + """Appends Masks to zarr directories""" def __init__( self, - zarr_dir, - channel_ids, - time_ids=-1, - pos_ids=-1, - num_workers=4, - mask_type="otsu", - overwrite_ok=False, + zarr_dir: Path, + channel_ids: list[int] | int, + time_ids: list[int] | int, + pos_ids: list[int] | int, + num_workers: int = 4, + mask_type: Literal[ + "otsu", "unimodal", "mem_detection", "borders_weight_loss_map" + ] = "otsu", + overwrite_ok: bool = False, ): - """ - :param str zarr_dir: directory of HCS zarr store to pull data from. - Note: data in store is assumed to be stored in - (time, channel, z, y, x) format. - :param list[int] channel_ids: Channel indices to be masked (typically - just one) - :param int/list channel_ids: generate mask from the sum of these - (flurophore) channel indices - :param list/int time_ids: timepoints to consider - :param int pos_ids: Position (FOV) indices to use - :param int num_workers: number of workers for multiprocessing - :param str mask_type: method to use for generating mask. Needed for - mapping to the masking function. One of: - {'otsu', 'unimodal', 'borders_weight_loss_map'} + """Initialize mask processor for generating masks from fluorophore channels. 
+ + Parameters + ---------- + zarr_dir : str + Directory of HCS zarr store to pull data from. Note: data in store is assumed to be stored in TCZYX format. + channel_ids : list[int] | int + Channel indices to be masked (typically just one) + time_ids : list[int] | int + Timepoints to consider + pos_ids : list[int] | int + Position (FOV) indices to use + num_workers : int + Number of workers for multiprocessing + mask_type : str + Method to use for generating mask. Needed for mapping to the masking function. + One of: {'otsu', 'unimodal', 'mem_detection', 'borders_weight_loss_map'}. Default is 'otsu'. + overwrite_ok : bool + Overwrite existing masks. Default is False. """ self.zarr_dir = zarr_dir self.num_workers = num_workers @@ -72,8 +79,9 @@ def __init__( print(f"Mask found in channel {mask_name}. Overwriting with this mask.") plate.close() - def generate_masks(self, structure_elem_radius=5): - """ + def generate_masks(self, structure_elem_radius: int = 5): + """Generate foreground masks from fluorophore channels. + The sum of flurophore channels is thresholded to generate a foreground mask. @@ -84,10 +92,11 @@ def generate_masks(self, structure_elem_radius=5): Masks are also saved as an additional untracked array named "mask" and tracked in the "mask" metadata field. - :param int structure_elem_radius: Radius of structuring element for - morphological operations + Parameters + ---------- + structure_elem_radius : int + Radius of structuring element for morphological operations """ - # Gather function arguments for each index pair at each position plate = ngff.open_ome_zarr(store_path=self.zarr_dir, mode="r+") diff --git a/viscy/preprocessing/pixel_ratio.py b/viscy/preprocessing/pixel_ratio.py index 29c2ed419..b7701849c 100644 --- a/viscy/preprocessing/pixel_ratio.py +++ b/viscy/preprocessing/pixel_ratio.py @@ -7,12 +7,22 @@ def sematic_class_weights( dataset_path: str, target_channel: str, num_classes: int = 3 ) -> NDArray: """Computes class balancing weights for semantic segmentation. + The weights can be used for cross-entropy loss. - :param str dataset_path: HCS OME-Zarr dataset path - :param str target_channel: target channel name - :param int num_classes: number of classes - :return NDArray: inverted ratio of background, uninfected and infected pixels + Parameters + ---------- + dataset_path : str + HCS OME-Zarr dataset path + target_channel : str + Target channel name + num_classes : int + Number of classes. Default is 3. + + Returns + ------- + NDArray + Inverted ratio of background, uninfected and infected pixels """ dataset = open_ome_zarr(dataset_path) arrays = [da.from_zarr(pos["0"]) for _, pos in dataset.positions()] diff --git a/viscy/preprocessing/precompute.py b/viscy/preprocessing/precompute.py index 1c68ad300..a23aa1e57 100644 --- a/viscy/preprocessing/precompute.py +++ b/viscy/preprocessing/precompute.py @@ -8,7 +8,6 @@ import dask.array as da from dask.diagnostics import ProgressBar from iohub.ngff import open_ome_zarr - from viscy.data.select import _filter_fovs, _filter_wells @@ -40,6 +39,27 @@ def precompute_array( include_wells: list[str] | None = None, exclude_fovs: list[str] | None = None, ) -> None: + """Precompute normalized image arrays for efficient data loading. + + Parameters + ---------- + data_path : Path + Path to HCS OME-Zarr dataset. + output_path : Path + Output path for precomputed arrays. + channel_names : list[str] + List of channel names to process. + subtrahends : list[Literal["mean"] | float] + Subtraction values for normalization per channel. 
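The inverted-ratio computation in sematic_class_weights amounts to something like the following sketch (synthetic labels; the actual function derives the ratios from the target channel of the OME-Zarr dataset and may normalize differently):

    import numpy as np

    labels = np.random.randint(0, 3, size=(64, 64))  # synthetic 3-class map
    ratios = np.bincount(labels.ravel(), minlength=3) / labels.size
    weights = 1.0 / ratios  # inverted pixel ratios, e.g. for cross-entropy loss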
+ divisors : list[Literal["std"] | tuple[float, float]] + Division values for normalization per channel. + image_array_key : str, optional + Array key in zarr store, by default "0". + include_wells : list[str] | None, optional + Wells to include, by default None (all wells). + exclude_fovs : list[str] | None, optional + FOVs to exclude, by default None (no exclusions). + """ normalized_images: list[da.Array] = [] with open_ome_zarr(data_path, layout="hcs", mode="r") as dataset: channel_indices = [dataset.channel_names.index(c) for c in channel_names] diff --git a/viscy/representation/classification.py b/viscy/representation/classification.py index 0b4ed58a8..27aac722b 100644 --- a/viscy/representation/classification.py +++ b/viscy/representation/classification.py @@ -1,24 +1,50 @@ from pathlib import Path +from typing import Any +import numpy as np import pandas as pd import torch -from lightning.pytorch import LightningModule +from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import BasePredictionWriter from torch import nn from torchmetrics.functional.classification import binary_accuracy, binary_f1_score - from viscy.representation.contrastive import ContrastiveEncoder from viscy.utils.log_images import render_images class ClassificationPredictionWriter(BasePredictionWriter): - def __init__(self, output_path: Path): + """Prediction writer callback for saving classification outputs to CSV. + + Collects predictions from all batches and writes them to a CSV file at the + end of each epoch. Converts tensor outputs to numpy arrays for storage. + """ + + def __init__(self, output_path: Path) -> None: super().__init__("epoch") if Path(output_path).exists(): raise FileExistsError(f"Output path {output_path} already exists.") self.output_path = output_path - def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): + def write_on_epoch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + predictions: list[dict[str, Any]], + batch_indices: list[int], + ) -> None: + """Write all predictions to CSV file at epoch end. + + Parameters + ---------- + trainer : lightning.Trainer + PyTorch Lightning trainer instance. + pl_module : lightning.LightningModule + Lightning module being trained. + predictions : list + List of prediction dictionaries from all batches. + batch_indices : list + Indices of batches processed during prediction. + """ all_predictions = [] for prediction in predictions: for key, value in prediction.items(): @@ -29,12 +55,19 @@ def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices): class ClassificationModule(LightningModule): + """Binary classification module using pre-trained contrastive encoder. + + Adapts a contrastive encoder for binary classification by replacing the + final linear layer and adding classification-specific training logic. + Computes binary cross-entropy loss and tracks accuracy and F1-score metrics. + """ + def __init__( self, encoder: ContrastiveEncoder, lr: float | None, loss: nn.Module | None = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(1.0)), - ): + ) -> None: super().__init__() self.stem = encoder.stem self.backbone = encoder.encoder @@ -43,15 +76,34 @@ def __init__( self.lr = lr self.example_input_array = torch.rand(2, 1, 15, 160, 160) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through stem and backbone for classification. 
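The epoch-end writer above flattens a list of per-batch prediction dictionaries into one table; in spirit (synthetic predictions, hypothetical file name):

    import pandas as pd
    import torch

    predictions = [
        {"label": torch.tensor([1.0]), "prediction": torch.tensor([0.83])},
        {"label": torch.tensor([0.0]), "prediction": torch.tensor([0.12])},
    ]
    rows = [{k: float(v) for k, v in p.items()} for p in predictions]
    pd.DataFrame(rows).to_csv("predictions.csv", index=False)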
+ + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, channels, depth, height, width). + + Returns + ------- + torch.Tensor + Logits tensor of shape (batch_size, 1) for binary classification. + """ x = self.stem(x) return self.backbone(x) - def on_fit_start(self): + def on_fit_start(self) -> None: + """Initialize example storage lists at start of training. + + Creates empty lists to store training and validation examples for + visualization logging during the training process. + """ self.train_examples = [] self.val_examples = [] - def _fit_step(self, batch, stage: str, loss_on_step: bool): + def _fit_step( + self, batch: tuple[torch.Tensor, torch.Tensor], stage: str, loss_on_step: bool + ) -> tuple[torch.Tensor, np.ndarray]: x, y = batch y_hat = self(x) loss = self.loss(y_hat, y) @@ -65,26 +117,79 @@ def _fit_step(self, batch, stage: str, loss_on_step: bool): ) return loss, x[0, 0, x.shape[2] // 2].detach().cpu().numpy() - def training_step(self, batch, batch_idx: int): + def training_step( + self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + """Execute single training step with loss computation and logging. + + Parameters + ---------- + batch : tuple + Training batch containing (inputs, targets). + batch_idx : int + Index of current batch within epoch. + + Returns + ------- + torch.Tensor + Training loss for backpropagation. + """ loss, example = self._fit_step(batch, "train", loss_on_step=True) if batch_idx < 4: self.train_examples.append([example]) return loss - def validation_step(self, batch, batch_idx: int): + def validation_step( + self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + """Execute single validation step with metrics computation. + + Parameters + ---------- + batch : tuple + Validation batch containing (inputs, targets). + batch_idx : int + Index of current batch within epoch. + + Returns + ------- + torch.Tensor + Validation loss for monitoring. + """ loss, example = self._fit_step(batch, "val", loss_on_step=False) if batch_idx < 4: self.val_examples.append([example]) return loss - def predict_step(self, batch, batch_idx: int, dataloader_idx: int | None = None): + def predict_step( + self, + batch: tuple[torch.Tensor, torch.Tensor, dict[str, Any]], + batch_idx: int, + dataloader_idx: int | None = None, + ) -> dict[str, Any]: + """Execute prediction step with sigmoid activation for probabilities. + + Parameters + ---------- + batch : tuple + Prediction batch containing (inputs, targets, indices). + batch_idx : int + Index of current batch. + dataloader_idx : int or None, optional + Index of dataloader when multiple dataloaders used. + + Returns + ------- + dict + Dictionary containing indices, labels, and sigmoid probabilities. + """ x, y, indices = batch y_hat = nn.functional.sigmoid(self(x)) indices["label"] = y indices["prediction"] = y_hat return indices - def _log_images(self, examples, stage): + def _log_images(self, examples: list[list[np.ndarray]], stage: str) -> None: image = render_images(examples) self.logger.experiment.add_image( f"{stage}/examples", @@ -93,13 +198,30 @@ def _log_images(self, examples, stage): dataformats="HWC", ) - def on_train_epoch_end(self): + def on_train_epoch_end(self) -> None: + """Log training examples and clear storage at epoch end. + + Renders and logs training examples to tensorboard, then clears the + examples list for the next epoch. 
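The split between training and prediction above follows the usual BCE-with-logits pattern: the loss consumes raw logits, while sigmoid is applied only when probabilities are needed. A self-contained sketch with synthetic data:

    import torch
    from torch import nn
    from torchmetrics.functional.classification import binary_accuracy

    logits = torch.randn(8, 1)
    targets = torch.randint(0, 2, (8, 1)).float()
    loss = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(1.0))(logits, targets)
    probs = torch.sigmoid(logits)  # only for prediction/metrics
    acc = binary_accuracy(probs, targets.int())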
+ """ self._log_images(self.train_examples, "train") self.train_examples.clear() - def on_validation_epoch_end(self): + def on_validation_epoch_end(self) -> None: + """Log validation examples and clear storage at epoch end. + + Renders and logs validation examples to tensorboard, then clears the + examples list for the next epoch. + """ self._log_images(self.val_examples, "val") self.val_examples.clear() - def configure_optimizers(self): + def configure_optimizers(self) -> torch.optim.AdamW: + """Configure AdamW optimizer for training. + + Returns + ------- + torch.optim.AdamW + AdamW optimizer with specified learning rate. + """ return torch.optim.AdamW(self.parameters(), lr=self.lr) diff --git a/viscy/representation/contrastive.py b/viscy/representation/contrastive.py index 8edeb8623..df6094ee9 100644 --- a/viscy/representation/contrastive.py +++ b/viscy/representation/contrastive.py @@ -3,7 +3,6 @@ import timm import torch.nn as nn from torch import Tensor - from viscy.unet.networks.unext2 import StemDepthtoChannels diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index 7b6296356..d4b1e9f62 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -1,6 +1,7 @@ import logging +from collections.abc import Sequence from pathlib import Path -from typing import Any, Dict, Literal, Optional, Sequence +from typing import Any, Literal import numpy as np import pandas as pd @@ -8,8 +9,6 @@ from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import BasePredictionWriter from numpy.typing import NDArray -from xarray import Dataset, open_zarr - from viscy.data.triplet import INDEX_COLUMNS from viscy.representation.engine import ContrastivePrediction from viscy.representation.evaluation.dimensionality_reduction import ( @@ -17,14 +16,15 @@ compute_pca, compute_phate, ) +from xarray import Dataset, open_zarr __all__ = ["read_embedding_dataset", "EmbeddingWriter", "write_embedding_dataset"] _logger = logging.getLogger("lightning.pytorch") def read_embedding_dataset(path: Path) -> Dataset: - """ - Read the embedding dataset written by the EmbeddingWriter callback. + """Read the embedding dataset written by the EmbeddingWriter callback. + Supports both legacy datasets (without x/y coordinates) and new datasets. 
Parameters @@ -63,10 +63,10 @@ def write_embedding_dataset( output_path: Path, features: np.ndarray, index_df: pd.DataFrame, - projections: Optional[np.ndarray] = None, - umap_kwargs: Optional[Dict[str, Any]] = None, - phate_kwargs: Optional[Dict[str, Any]] = None, - pca_kwargs: Optional[Dict[str, Any]] = None, + projections: np.ndarray | None = None, + umap_kwargs: dict[str, Any] | None = None, + phate_kwargs: dict[str, Any] | None = None, + pca_kwargs: dict[str, Any] | None = None, overwrite: bool = False, ) -> None: """ @@ -221,6 +221,7 @@ def __init__( self.overwrite = overwrite def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """Initialize prediction writing and validate output path.""" if self.output_path.exists(): raise FileExistsError(f"Output path {self.output_path} already exists.") _logger.debug(f"Writing embeddings to {self.output_path}") diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 7a35d93f1..1e7ae93aa 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -1,5 +1,6 @@ import logging -from typing import Literal, Sequence, TypedDict +from collections.abc import Sequence +from typing import Literal, TypedDict import numpy as np import torch @@ -8,7 +9,6 @@ from pytorch_metric_learning.losses import NTXentLoss from torch import Tensor, nn from umap import UMAP - from viscy.data.typing import TrackingIndex, TripletSample from viscy.representation.contrastive import ContrastiveEncoder from viscy.utils.log_images import detach_sample, render_images @@ -17,6 +17,12 @@ class ContrastivePrediction(TypedDict): + """Typed dictionary for contrastive model predictions. + + Contains features, projections, and metadata for contrastive learning + inference outputs. + """ + features: Tensor projections: Tensor index: TrackingIndex @@ -66,12 +72,34 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: return self.model(x) def log_feature_statistics(self, embeddings: Tensor, prefix: str): + """Log embedding statistics for monitoring training dynamics. + + Parameters + ---------- + embeddings : Tensor + Embedding vectors to analyze. + prefix : str + Prefix for logging keys. + """ mean = torch.mean(embeddings, dim=0).detach().cpu().numpy() std = torch.std(embeddings, dim=0).detach().cpu().numpy() _logger.debug(f"{prefix}_mean: {mean}") _logger.debug(f"{prefix}_std: {std}") def print_embedding_norms(self, anchor, positive, negative, phase): + """Log L2 norms of embeddings for triplet components. + + Parameters + ---------- + anchor : Tensor + Anchor embeddings. + positive : Tensor + Positive embeddings. + negative : Tensor + Negative embeddings. + phase : str + Training phase identifier for logging. + """ anchor_norm = torch.norm(anchor, dim=1).mean().item() positive_norm = torch.norm(positive, dim=1).mean().item() negative_norm = torch.norm(negative, dim=1).mean().item() @@ -133,6 +161,15 @@ def _log_step_samples(self, batch_idx, samples, stage: Literal["train", "val"]): output_list.extend(detach_sample(samples, self.log_samples_per_batch)) def log_embedding_umap(self, embeddings: Tensor, tag: str): + """Log UMAP visualization of embedding space to TensorBoard. + + Parameters + ---------- + embeddings : Tensor + High-dimensional embeddings to visualize. + tag : str + Tag for TensorBoard logging. 
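The contrastive engine imports NTXentLoss from pytorch_metric_learning; one common way to drive it with anchor/positive pairs is shown below (random projections for illustration; the module may batch differently):

    import torch
    from pytorch_metric_learning.losses import NTXentLoss

    loss_fn = NTXentLoss(temperature=0.5)
    anchor_proj = torch.randn(16, 32)
    positive_proj = torch.randn(16, 32)
    embeddings = torch.cat([anchor_proj, positive_proj])
    labels = torch.arange(16).repeat(2)  # same label marks a positive pair
    loss = loss_fn(embeddings, labels)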
+ """ _logger.debug(f"Computing UMAP for {tag} embeddings.") umap = UMAP(n_components=2) embeddings_np = embeddings.detach().cpu().numpy() @@ -146,6 +183,23 @@ def log_embedding_umap(self, embeddings: Tensor, tag: str): ) def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: + """Execute training step for contrastive learning. + + Computes triplet or NT-Xent loss based on configured loss function + and logs training metrics. + + Parameters + ---------- + batch : TripletSample + Batch containing anchor, positive, and negative samples. + batch_idx : int + Index of current batch. + + Returns + ------- + Tensor + Computed contrastive loss. + """ anchor_img = batch["anchor"] pos_img = batch["positive"] _, anchor_projection = self(anchor_img) @@ -177,6 +231,11 @@ def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: return loss def on_train_epoch_end(self) -> None: + """Log training samples and embeddings at epoch end. + + Logs sample images and optionally computes UMAP visualization + of embedding space for monitoring training progress. + """ super().on_train_epoch_end() self._log_samples("train_samples", self.training_step_outputs) # Log UMAP embeddings for validation @@ -220,6 +279,11 @@ def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: return loss def on_validation_epoch_end(self) -> None: + """Log validation samples and embeddings at epoch end. + + Logs sample images and optionally computes UMAP visualization + of embedding space for monitoring validation performance. + """ super().on_validation_epoch_end() self._log_samples("val_samples", self.validation_step_outputs) # Log UMAP embeddings for training @@ -232,6 +296,13 @@ def on_validation_epoch_end(self) -> None: self.validation_step_outputs = [] def configure_optimizers(self): + """Configure optimizer for contrastive learning. + + Returns + ------- + torch.optim.Optimizer + AdamW optimizer with configured learning rate. + """ optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) return optimizer diff --git a/viscy/representation/evaluation/__init__.py b/viscy/representation/evaluation/__init__.py index c474aec82..36d899b2b 100644 --- a/viscy/representation/evaluation/__init__.py +++ b/viscy/representation/evaluation/__init__.py @@ -1,5 +1,6 @@ -""" -This module enables evaluation of learned representations using annotations, such as +"""Evaluation tools for learned representations using various annotation types. + +Enables evaluation of learned representations using annotations, such as: * cell division labels, * infection state labels, * labels predicted using supervised classifiers, @@ -14,18 +15,22 @@ https://github.com/mehta-lab/dynacontrast/blob/master/analysis/gmm.py """ -import pandas as pd +from pathlib import Path +import pandas as pd +import xarray as xr from viscy.data.triplet import TripletDataModule -def load_annotation(da, path, name, categories: dict | None = None): +def load_annotation( + da: xr.DataArray, path: str, name: str, categories: dict | None = None +) -> pd.Series: """ Load annotations from a CSV file and map them to the dataset. Parameters ---------- - da : xarray.DataArray + da : xr.DataArray The dataset array containing 'fov_name' and 'id' coordinates. path : str Path to the CSV file containing annotations. 
@@ -64,15 +69,41 @@ def dataset_of_tracks
 def dataset_of_tracks(
-    data_path,
-    tracks_path,
-    fov_list,
-    track_id_list,
-    source_channel=["Phase3D", "RFP"],
-    z_range=(28, 43),
-    initial_yx_patch_size=(128, 128),
-    final_yx_patch_size=(128, 128),
+    data_path: str | Path,
+    tracks_path: str | Path,
+    fov_list: list[str],
+    track_id_list: list[int],
+    source_channel: list[str] = ["Phase3D", "RFP"],
+    z_range: tuple[int, int] = (28, 43),
+    initial_yx_patch_size: tuple[int, int] = (128, 128),
+    final_yx_patch_size: tuple[int, int] = (128, 128),
 ):
+    """Create a prediction dataset from tracks for evaluation.
+
+    Parameters
+    ----------
+    data_path : str | Path
+        Path to the data directory containing image files.
+    tracks_path : str | Path
+        Path to the tracks data file.
+    fov_list : list[str]
+        List of field of view names to include.
+    track_id_list : list[int]
+        List of track IDs to include.
+    source_channel : list[str], optional
+        List of source channel names, by default ["Phase3D", "RFP"].
+    z_range : tuple[int, int], optional
+        Z-stack range as (start, end), by default (28, 43).
+    initial_yx_patch_size : tuple[int, int], optional
+        Initial patch size in YX dimensions, by default (128, 128).
+    final_yx_patch_size : tuple[int, int], optional
+        Final patch size in YX dimensions, by default (128, 128).
+
+    Returns
+    -------
+    Dataset
+        Prediction dataset of single-cell patches for the given tracks.
+    """
     data_module = TripletDataModule(
         data_path=data_path,
         tracks_path=tracks_path,
@@ -84,7 +115,7 @@ def dataset_of_tracks
         final_yx_patch_size=final_yx_patch_size,
         batch_size=1,
         num_workers=16,
-        normalizations=None,
+        normalizations=[],
         predict_cells=True,
     )
     # for train and val
diff --git a/viscy/representation/evaluation/clustering.py b/viscy/representation/evaluation/clustering.py
index ebf49455f..dbdc6455c 100644
--- a/viscy/representation/evaluation/clustering.py
+++ b/viscy/representation/evaluation/clustering.py
@@ -85,8 +85,9 @@ def select_block(distances: NDArray, index: NDArray) -> NDArray:
 def compare_time_offset(
     single_track_distances: NDArray, time_offset: int = 1
 ) -> NDArray:
-    """Extract the nearest neighbor distances/rankings
-    of the next sample compared to each sample.
+    """Extract the nearest neighbor distances/rankings of the next sample.
+
+    The comparison is made against each current sample along the track.
 
     Parameters
     ----------
@@ -105,12 +106,14 @@ def compare_time_offset(
     return single_track_distances.diagonal(offset=-time_offset)
 
 
-def dbscan_clustering(embeddings, eps=0.5, min_samples=5):
+def dbscan_clustering(embeddings: NDArray, eps: float = 0.5, min_samples: int = 5):
     """
     Apply DBSCAN clustering to the embeddings.
 
     Parameters
     ----------
+    embeddings : NDArray
+        Embeddings to cluster.
     eps : float, optional
         The maximum distance between two samples for them to be considered as in the same neighborhood. Default is 0.5.
     min_samples : int, optional
@@ -118,7 +121,7 @@ def dbscan_clustering(
 
     Returns
     -------
-    np.ndarray
+    NDArray
         Clustering labels assigned by DBSCAN.
     """
     dbscan = DBSCAN(eps=eps, min_samples=min_samples)
@@ -126,12 +129,16 @@ def dbscan_clustering(
     return clusters
 
 
-def clustering_evaluation(embeddings, annotations, method="nmi"):
+def clustering_evaluation(embeddings: NDArray, annotations: NDArray, method="nmi"):
     """
     Evaluate the clustering of the embeddings compared to the ground truth labels.
 
     Parameters
     ----------
+    embeddings : NDArray
+        Embeddings to cluster.
+    annotations : NDArray
+        Ground truth labels.
method : str, optional Metric to use for evaluation ('nmi' or 'ari'). Default is 'nmi'. diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index eb5d43f91..5b0db1cb7 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -2,14 +2,14 @@ import pandas as pd import umap +import xarray as xr from numpy.typing import NDArray from sklearn.decomposition import PCA from sklearn.preprocessing import StandardScaler -from xarray import Dataset def compute_phate( - embedding_dataset, + embedding_dataset: NDArray | xr.Dataset, n_components: int = 2, knn: int = 5, decay: int = 40, @@ -21,7 +21,7 @@ def compute_phate( Parameters ---------- - embedding_dataset : xarray.Dataset or NDArray + embedding_dataset : xr.Dataset | NDArray The dataset containing embeddings, timepoints, fov_name, and track_id, or a numpy array of embeddings. n_components : int, optional @@ -55,7 +55,7 @@ def compute_phate( # Get embeddings from dataset if needed embeddings = ( embedding_dataset["features"].values - if isinstance(embedding_dataset, Dataset) + if isinstance(embedding_dataset, xr.Dataset) else embedding_dataset ) @@ -66,7 +66,7 @@ def compute_phate( phate_embedding = phate_model.fit_transform(embeddings) # Update dataset if requested - if update_dataset and isinstance(embedding_dataset, Dataset): + if update_dataset and isinstance(embedding_dataset, xr.Dataset): for i in range( min(2, phate_embedding.shape[1]) ): # Only update PHATE1 and PHATE2 @@ -80,7 +80,7 @@ def compute_pca(embedding_dataset, n_components=None, normalize_features=True): Parameters ---------- - embedding_dataset : xarray.Dataset or NDArray + embedding_dataset : xr.Dataset or NDArray The dataset containing embeddings, timepoints, fov_name, and track_id, or a numpy array of embeddings. n_components : int, optional @@ -93,10 +93,9 @@ def compute_pca(embedding_dataset, n_components=None, normalize_features=True): tuple[NDArray, pd.DataFrame] PCA embeddings and PCA DataFrame """ - embeddings = ( embedding_dataset["features"].values - if isinstance(embedding_dataset, Dataset) + if isinstance(embedding_dataset, xr.Dataset) else embedding_dataset ) @@ -110,7 +109,7 @@ def compute_pca(embedding_dataset, n_components=None, normalize_features=True): pc_features = PCA_features.fit_transform(scaled_features) # Create base dictionary with id and fov_name - if isinstance(embedding_dataset, Dataset): + if isinstance(embedding_dataset, xr.Dataset): pca_dict = { "id": embedding_dataset["id"].values, "fov_name": embedding_dataset["fov_name"].values, @@ -142,13 +141,13 @@ def _fit_transform_umap( def compute_umap( - embedding_dataset: Dataset, normalize_features: bool = True + embedding_dataset: xr.Dataset, normalize_features: bool = True ) -> tuple[umap.UMAP, umap.UMAP, pd.DataFrame]: """Compute UMAP embeddings for features and projections. Parameters ---------- - embedding_dataset : Dataset + embedding_dataset : xr.Dataset Xarray dataset with features and projections. 
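These wrappers reduce to standard scikit-learn calls; a compact end-to-end sketch on synthetic embeddings (scaling, PCA, DBSCAN, then NMI against random labels):

    import numpy as np
    from sklearn.cluster import DBSCAN
    from sklearn.decomposition import PCA
    from sklearn.metrics import normalized_mutual_info_score
    from sklearn.preprocessing import StandardScaler

    features = np.random.rand(200, 64)  # synthetic embeddings
    scaled = StandardScaler().fit_transform(features)
    pcs = PCA(n_components=2).fit_transform(scaled)
    labels_pred = DBSCAN(eps=0.5, min_samples=5).fit_predict(pcs)
    labels_true = np.random.randint(0, 3, 200)
    nmi = normalized_mutual_info_score(labels_true, labels_pred)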
     normalize_features : bool, optional
         Scale the input to zero mean and unit variance before fitting UMAP,
diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py
index a920eb072..fae2e2248 100644
--- a/viscy/representation/evaluation/distance.py
+++ b/viscy/representation/evaluation/distance.py
@@ -2,10 +2,13 @@
 from typing import Literal
 
 import numpy as np
+import xarray as xr
 from sklearn.metrics.pairwise import cosine_similarity
 
 
-def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id):
+def calculate_cosine_similarity_cell(
+    embedding_dataset: xr.Dataset, fov_name: str, track_id: int
+):
     """Extract embeddings and calculate cosine similarities for a specific cell"""
     filtered_data = embedding_dataset.where(
         (embedding_dataset["fov_name"] == fov_name)
@@ -22,7 +25,7 @@ def calculate_cosine_similarity_cell(
 
 
 def compute_displacement(
-    embedding_dataset,
+    embedding_dataset: xr.Dataset,
     distance_metric: Literal["euclidean_squared", "cosine"] = "euclidean_squared",
 ) -> dict[int, list[float]]:
     """Compute the displacement or mean square displacement (MSD) of embeddings.
@@ -130,31 +133,38 @@ def compute_displacement_statistics(
     return mean_displacement_per_tau, std_displacement_per_tau
 
 
-def compute_dynamic_range(mean_displacement_per_tau):
-    """
-    Compute the dynamic range as the difference between the maximum
-    and minimum mean displacement per τ.
+def compute_dynamic_range(mean_displacement_per_tau: dict[int, float]) -> float:
+    """Compute the dynamic range of the mean displacement across time offsets.
 
-    Parameters:
-    mean_displacement_per_tau: dict with τ as key and mean displacement as value
+    It is the difference between the maximum and minimum mean displacement per τ.
 
-    Returns:
-    float: dynamic range (max displacement - min displacement)
+    Parameters
+    ----------
+    mean_displacement_per_tau : dict[int, float]
+        Dictionary with τ as key and mean displacement as value.
+
+    Returns
+    -------
+    float
+        Dynamic range (max displacement - min displacement).
     """
     displacements = list(mean_displacement_per_tau.values())
     return max(displacements) - min(displacements)
 
 
-def compute_rms_per_track(embedding_dataset):
+def compute_rms_per_track(embedding_dataset: xr.Dataset) -> list:
     """
     Compute RMS of the time derivative of embeddings per track.
 
-    Parameters:
+    Parameters
+    ----------
     embedding_dataset : xarray.Dataset
         The dataset containing embeddings, timepoints, fov_name, and track_id.
 
-    Returns:
-    list: A list of RMS values, one for each track.
+    Returns
+    -------
+    list
+        A list of RMS values, one for each track.
     """
     fov_names = embedding_dataset["fov_name"].values
     track_ids = embedding_dataset["track_id"].values
@@ -193,7 +203,25 @@ def compute_rms_per_track(
     return rms_values
 
 
-def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id):
+def calculate_normalized_euclidean_distance_cell(
+    embedding_dataset: xr.Dataset, fov_name: str, track_id: int
+):
+    """Calculate normalized Euclidean distance for a specific cell track.
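The displacement metrics above reduce to differences of embeddings at increasing time offsets; a self-contained sketch of squared-Euclidean MSD per τ and the resulting dynamic range (one synthetic track):

    import numpy as np

    track = np.random.rand(50, 32)  # one track's embeddings over 50 time points
    msd_per_tau = {
        tau: float(np.mean(np.sum((track[tau:] - track[:-tau]) ** 2, axis=1)))
        for tau in range(1, 10)
    }
    dynamic_range = max(msd_per_tau.values()) - min(msd_per_tau.values())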
+ + Parameters + ---------- + embedding_dataset : xr.Dataset + Dataset containing embedding data with fov_name and track_id coordinates + fov_name : str + Field of view identifier + track_id : int + Track identifier for the specific cell + + Returns + ------- + NDArray + Normalized euclidean distances for the cell track + """ filtered_data = embedding_dataset.where( (embedding_dataset["fov_name"] == fov_name) & (embedding_dataset["track_id"] == track_id), diff --git a/viscy/representation/evaluation/feature.py b/viscy/representation/evaluation/feature.py index 4b0896c84..c7ec70647 100644 --- a/viscy/representation/evaluation/feature.py +++ b/viscy/representation/evaluation/feature.py @@ -5,7 +5,7 @@ import pandas as pd import scipy.stats from numpy import fft -from numpy.typing import ArrayLike +from numpy.typing import ArrayLike, NDArray from scipy.ndimage import distance_transform_edt from scipy.stats import linregress from skimage.exposure import rescale_intensity @@ -127,7 +127,7 @@ def __init__(self, image: ArrayLike, segmentation_mask: ArrayLike | None = None) self._eps = 1e-10 - def _compute_kurtosis(self): + def _compute_kurtosis(self) -> float: """Compute the kurtosis of the image. Returns @@ -140,7 +140,7 @@ def _compute_kurtosis(self): return np.nan return scipy.stats.kurtosis(self.image, fisher=True, axis=None) - def _compute_skewness(self): + def _compute_skewness(self) -> float: """Compute the skewness of the image. Returns @@ -153,7 +153,7 @@ def _compute_skewness(self): return np.nan return scipy.stats.skew(self.image, axis=None) - def _compute_glcm_features(self): + def _compute_glcm_features(self) -> tuple[float, float, float]: """Compute GLCM-based texture features from the image. Converts normalized image to uint8 for GLCM computation. @@ -169,7 +169,7 @@ def _compute_glcm_features(self): return contrast, dissimilarity, homogeneity - def _compute_iqr(self): + def _compute_iqr(self) -> float: """Compute the interquartile range of pixel intensities. The IQR is observed to increase when a cell is infected, @@ -184,7 +184,7 @@ def _compute_iqr(self): return iqr - def _compute_weighted_intensity_gradient(self): + def _compute_weighted_intensity_gradient(self) -> float: """Compute the weighted radial intensity gradient profile. Calculates the slope of the azimuthally averaged radial gradient @@ -241,7 +241,7 @@ def _compute_weighted_intensity_gradient(self): return slope - def _compute_spectral_entropy(self): + def _compute_spectral_entropy(self) -> float: """Compute the spectral entropy of the image. Spectral entropy measures the complexity of the image's frequency @@ -268,17 +268,22 @@ def _compute_spectral_entropy(self): return entropy - def _compute_texture_features(self): + def _compute_texture_features(self) -> NDArray: """Compute Haralick texture features from the image. Converts normalized image to uint8 for Haralick computation. + + Returns + ------- + texture_features: NDArray + Haralick texture features of the image. """ # Convert 0-1 normalized image to uint8 (0-255) image_uint8 = (self.image_normalized * 255).astype(np.uint8) texture_features = mh.features.haralick(image_uint8) return np.mean(np.ptp(texture_features, axis=0)) - def _compute_perimeter_area_ratio(self): + def _compute_perimeter_area_ratio(self) -> tuple[float, float, float]: """Compute the perimeter of the nuclear segmentations found inside the patch. 
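The intensity statistics listed here map directly onto scipy and numpy calls; a sketch on a random image (the class normalizes and masks the patch before computing these, which is omitted):

    import numpy as np
    import scipy.stats

    image = np.random.rand(128, 128)
    kurtosis = scipy.stats.kurtosis(image, fisher=True, axis=None)
    skewness = scipy.stats.skew(image, axis=None)
    iqr = np.percentile(image, 75) - np.percentile(image, 25)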
This function calculates the average perimeter, average area, and their ratio @@ -286,14 +291,8 @@ def _compute_perimeter_area_ratio(self): Returns ------- - average_perimeter, average_area, ratio: tuple - Tuple containing: - - average_perimeter : float - Average perimeter of all regions in the patch - - average_area : float - Average area of all regions - - ratio : float - Ratio of total perimeter to total area + tuple[float, float, float] + Tuple containing average perimeter, average area, and ratio of total perimeter to total area """ total_perimeter = 0 total_area = 0 @@ -314,7 +313,7 @@ def _compute_perimeter_area_ratio(self): return average_perimeter, average_area, total_perimeter / total_area - def _compute_nucleus_eccentricity(self): + def _compute_nucleus_eccentricity(self) -> float: """Compute the eccentricity of the nucleus. Eccentricity measures how much the nucleus deviates from @@ -336,7 +335,7 @@ def _compute_nucleus_eccentricity(self): eccentricities = [region.eccentricity for region in regions] return float(np.mean(eccentricities)) - def _compute_Eucledian_distance_transform(self): + def _compute_Eucledian_distance_transform(self) -> NDArray: """Compute the Euclidean distance transform of the segmentation mask. This transform computes the distance from each pixel to the @@ -345,7 +344,7 @@ def _compute_Eucledian_distance_transform(self): Returns ------- - dist_transform: ndarray + dist_transform: NDArray Distance transform of the segmentation mask. """ # Ensure the image is binary @@ -376,7 +375,7 @@ def _compute_intensity_localization(self): intensity_weighted_center = np.sum(self.image * edt) / (np.sum(edt) + self._eps) return intensity_weighted_center - def _compute_area(self, sigma=0.6): + def _compute_area(self, sigma: float = 0.6) -> tuple[float, float]: """Create a binary mask using morphological operations. This function creates a binary mask from the input image using Gaussian blur @@ -391,12 +390,8 @@ def _compute_area(self, sigma=0.6): Returns ------- - masked_intensity, masked_area: tuple - Tuple containing: - - masked_intensity : float - Mean intensity inside the sensor area - - masked_area : float - Area of the sensor mask in pixels + tuple[float, float] + Tuple containing masked intensity and masked area """ input_image_blur = gaussian(self.image, sigma=sigma) @@ -411,7 +406,7 @@ def _compute_area(self, sigma=0.6): return masked_intensity, np.sum(mask) - def _compute_zernike_moments(self): + def _compute_zernike_moments(self) -> NDArray: """Compute the Zernike moments of the image. Zernike moments are a set of orthogonal moments that capture @@ -420,16 +415,21 @@ def _compute_zernike_moments(self): Returns ------- - zernike_moments: np.ndarray + zernike_moments: NDArray Zernike moments of the image. """ zernike_moments = mh.features.zernike_moments(self.image, 32) return zernike_moments - def _compute_radial_intensity_gradient(self): + def _compute_radial_intensity_gradient(self) -> float: """Compute the radial intensity gradient of the image. Uses 0-1 normalized image directly for gradient calculation. + + Returns + ------- + radial_intensity_gradient: float + Radial intensity gradient of the image. """ # Use 0-1 normalized image directly y, x = np.indices(self.image_normalized.shape) @@ -447,7 +447,7 @@ def _compute_radial_intensity_gradient(self): return radial_intensity_gradient[0] - def compute_intensity_features(self): + def compute_intensity_features(self) -> IntensityFeatures: """Compute intensity features. 
This function computes various intensity-based features from the input image. @@ -471,7 +471,7 @@ def compute_intensity_features(self): weighted_intensity_gradient=self._compute_weighted_intensity_gradient(), ) - def compute_texture_features(self): + def compute_texture_features(self) -> TextureFeatures: """Compute texture features. This function computes texture features from the input image. @@ -493,7 +493,7 @@ def compute_texture_features(self): texture=self._compute_texture_features(), ) - def compute_morphology_features(self): + def compute_morphology_features(self) -> MorphologyFeatures: """Compute morphology features. This function computes morphology features from the input image. @@ -528,7 +528,7 @@ def compute_morphology_features(self): masked_area=masked_area, ) - def compute_symmetry_descriptor(self): + def compute_symmetry_descriptor(self) -> SymmetryDescriptor: """Compute the symmetry descriptor of the image. This function computes the symmetry descriptor of the image. @@ -615,20 +615,20 @@ class DynamicFeatures: Parameters ---------- - tracking_df : pandas.DataFrame + tracking_df : pd.DataFrame DataFrame containing cell tracking data with track_id, t, x, y columns Attributes ---------- - tracking_df : pandas.DataFrame + tracking_df : pd.DataFrame The input tracking dataframe containing cell position data over time - track_features : TrackFeatures or None + track_features : TrackFeatures | None Computed velocity-based features including mean, max, min velocities and their standard deviation - displacement_features : DisplacementFeatures or None + displacement_features : DisplacementFeatures | None Computed displacement features including total distance traveled, net displacement, and directional persistence - angular_features : AngularFeatures or None + angular_features : AngularFeatures | None Computed angular features including mean, max, and standard deviation of angular velocities @@ -657,7 +657,7 @@ def __init__(self, tracking_df: pd.DataFrame): if not np.issubdtype(tracking_df[col].dtype, np.number): raise ValueError(f"Column {col} must be numeric") - def _compute_instantaneous_velocity(self, track_id: str) -> np.ndarray: + def _compute_instantaneous_velocity(self, track_id: str) -> NDArray: """Compute the instantaneous velocity for all timepoints in a track. Parameters @@ -667,7 +667,7 @@ def _compute_instantaneous_velocity(self, track_id: str) -> np.ndarray: Returns ------- - velocities : np.ndarray + velocities : NDArray Array of instantaneous velocities for each timepoint """ # Get track data sorted by time @@ -708,15 +708,12 @@ def _compute_displacement(self, track_id: str) -> tuple[float, float, float]: Returns ------- - total_distance, net_displacement, directional_persistence: tuple - Tuple containing: - - total_distance : float - Total distance traveled by the cell along its path - - net_displacement : float - Straight-line distance between start and end positions - - directional_persistence : float - Ratio of net displacement to total distance (0 to 1), - where 1 indicates perfectly straight movement + tuple[float, float, float] + Tuple containing total distance, net displacement, and directional persistence + - total_distance: Total distance traveled by the cell along its path. + - net_displacement: Straight-line distance between start and end positions. + - directional_persistence: Ratio of net displacement to total distance (0 to 1), + where 1 indicates perfectly straight movement. 
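The displacement quantities defined above follow from per-frame position differences; a minimal sketch for one synthetic track (a persistence of 1.0 means a perfectly straight path):

    import numpy as np
    import pandas as pd

    track = pd.DataFrame(
        {"t": [0, 1, 2, 3], "x": [0.0, 1.0, 2.0, 2.0], "y": [0.0, 0.0, 1.0, 2.0]}
    )
    dx, dy = np.diff(track["x"]), np.diff(track["y"])
    steps = np.hypot(dx, dy)  # instantaneous displacement per frame
    total_distance = steps.sum()
    net_displacement = np.hypot(
        track["x"].iloc[-1] - track["x"].iloc[0],
        track["y"].iloc[-1] - track["y"].iloc[0],
    )
    persistence = net_displacement / total_distance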
""" track_data = self.tracking_df[ self.tracking_df["track_id"] == track_id @@ -758,11 +755,11 @@ def _compute_angular_velocity(self, track_id: str) -> tuple[float, float, float] Returns ------- - mean_angular_velocity, max_angular_velocity, std_angular_velocity: tuple - Tuple containing: - - mean_angular_velocity - - max_angular_velocity - - std_angular_velocity + tuple[float, float, float] + Tuple containing mean, maximum, and standard deviation of angular velocities + - mean_angular_velocity: Average angular velocity over the track. + - max_angular_velocity: Maximum angular velocity observed in the track. + - std_angular_velocity: Standard deviation of angular velocities in the track. """ track_data = self.tracking_df[ self.tracking_df["track_id"] == track_id diff --git a/viscy/representation/evaluation/lca.py b/viscy/representation/evaluation/lca.py index 7c5216193..89e2f6142 100644 --- a/viscy/representation/evaluation/lca.py +++ b/viscy/representation/evaluation/lca.py @@ -1,23 +1,22 @@ """Linear probing of trained encoder based on cell state labels.""" -from typing import Mapping +from collections.abc import Mapping import pandas as pd import torch import torch.nn as nn +import xarray as xr from captum.attr import IntegratedGradients, Occlusion from numpy.typing import NDArray from sklearn.linear_model import LogisticRegression from sklearn.metrics import classification_report from sklearn.preprocessing import StandardScaler from torch import Tensor -from xarray import DataArray - from viscy.representation.contrastive import ContrastiveEncoder def fit_logistic_regression( - features: DataArray, + features: xr.DataArray, annotations: pd.Series, train_fovs: list[str], remove_background_class: bool = True, @@ -33,7 +32,7 @@ def fit_logistic_regression( Parameters ---------- - features : DataArray + features : xr.DataArray Xarray of features. annotations : pd.Series Categorical class annotations with label values starting from 0. @@ -139,11 +138,37 @@ def __init__(self, backbone: ContrastiveEncoder, classifier: nn.Linear) -> None: @staticmethod def scale_features(x: Tensor) -> Tensor: + """Scale features using standardization. + + Parameters + ---------- + x : Tensor + Input tensor to scale + + Returns + ------- + Tensor + Scaled tensor with zero mean and unit variance + """ m = x.mean(-2, keepdim=True) s = x.std(-2, unbiased=False, keepdim=True) return (x - m) / s def forward(self, x: Tensor, scale_features: bool = False) -> Tensor: + """Forward pass through the LCA backbone. 
+ + Parameters + ---------- + x : Tensor + Input tensor + scale_features : bool, optional + Whether to apply feature scaling, by default False + + Returns + ------- + Tensor + Encoded feature representations + """ x = self.backbone.stem(x) x = self.backbone.encoder(x) if scale_features: diff --git a/viscy/representation/evaluation/visualization.py b/viscy/representation/evaluation/visualization.py index 9d787fe05..f6e433ddc 100644 --- a/viscy/representation/evaluation/visualization.py +++ b/viscy/representation/evaluation/visualization.py @@ -4,6 +4,7 @@ import logging from io import BytesIO from pathlib import Path +from typing import Any import dash import dash.dependencies as dd @@ -12,10 +13,10 @@ import pandas as pd import plotly.graph_objects as go from dash import dcc, html +from numpy.typing import NDArray from PIL import Image from sklearn.decomposition import PCA from sklearn.preprocessing import StandardScaler - from viscy.data.triplet import TripletDataModule from viscy.representation.embedding_writer import read_embedding_dataset @@ -24,11 +25,18 @@ class EmbeddingVisualizationApp: + """Interactive visualization app for embedding analysis. + + Provides a Dash-based web application for exploring embeddings with PCA + visualization, track selection, and image display capabilities for + representation learning analysis. + """ + def __init__( self, - data_path: str, - tracks_path: str, - features_path: str, + data_path: str | Path, + tracks_path: str | Path, + features_path: str | Path, channels_to_display: list[str] | str, fov_tracks: dict[str, list[int] | str], z_range: tuple[int, int] = (0, 1), @@ -46,11 +54,11 @@ def __init__( Parameters ---------- - data_path: str + data_path: str | Path Path to the data directory. - tracks_path: str + tracks_path: str | Path Path to the tracks directory. - features_path: str + features_path: str | Path Path to the features directory. channels_to_display: list[str] | str List of channels to display. @@ -68,6 +76,7 @@ def __init__( Number of workers to use for loading data. output_dir: str | None, optional Directory to save CSV files and other outputs. If None, uses current working directory. + Returns ------- None @@ -101,7 +110,7 @@ def __init__( self._init_app() atexit.register(self._cleanup_cache) - def _prepare_data(self): + def _prepare_data(self) -> None: """Prepare the feature data and PCA transformation""" embedding_dataset = read_embedding_dataset(self.features_path) features = embedding_dataset["features"] @@ -182,11 +191,11 @@ def _prepare_data(self): # Combine all filtered features self.filtered_features_df = pd.concat(all_filtered_features, axis=0) - def _create_figure(self): + def _create_figure(self) -> None: """Create the initial scatter plot figure""" self.fig = self._create_track_colored_figure() - def _init_app(self): + def _init_app(self) -> None: """Initialize the Dash application""" self.app = dash.Dash(__name__) @@ -509,14 +518,29 @@ def _init_app(self): prevent_initial_call=True, ) def update_figure( - color_mode, - show_arrows, - x_axis, - y_axis, - relayout_data, - selected_data, - current_figure, - ): + color_mode: str, + show_arrows: list[str] | None, + x_axis: str, + y_axis: str, + relayout_data: dict[str, Any] | None, + selected_data: dict[str, Any] | None, + current_figure: dict[str, Any], + ) -> tuple[dict[str, Any], dict[str, Any] | None]: + """Update the figure based on the selected data. + + Parameters + ---------- + color_mode: str + The color mode. + show_arrows: list[str] | None + The show arrows. 
+ x_axis: str + The x axis. + y_axis: str + The y axis. + """ + if show_arrows is None: + show_arrows = [] show_arrows = len(show_arrows or []) > 0 ctx = dash.callback_context @@ -554,8 +578,19 @@ def update_figure( [dd.Input("scatter-plot", "clickData")], prevent_initial_call=True, ) - def update_track_timeline(clickData): - """Update the track timeline based on the clicked point""" + def update_track_timeline(clickData: dict[str, Any] | None) -> html.Div: + """Update the track timeline based on the clicked point + + Parameters + ---------- + clickData: dict[str, Any] | None + The click data from the scatter plot. + + Returns + ------- + html.Div: The track timeline. + + """ if clickData is None: return html.Div("Click on a point to see the track timeline") @@ -727,19 +762,61 @@ def update_track_timeline(clickData): prevent_initial_call=True, ) def update_clusters_tab( - assign_clicks, - clear_clicks, - save_name_clicks, - cancel_name_clicks, - edit_name_clicks, - selected_data, - current_figure, - color_mode, - show_arrows, - x_axis, - y_axis, - cluster_name, - ): + assign_clicks: int | None, + clear_clicks: int | None, + save_name_clicks: int | None, + cancel_name_clicks: int | None, + edit_name_clicks: list[int], + selected_data: dict[str, Any] | None, + current_figure: dict[str, Any], + color_mode: str, + show_arrows: list[str] | None, + x_axis: str, + y_axis: str, + cluster_name: str | None, + ) -> tuple[ + dict[str, str], + html.Div | None, + str, + dict[str, Any] | Any, + dict[str, str], + str, + dict[str, Any] | None, + ]: + """Update the clusters tab and handle modal. + + Parameters + ---------- + assign_clicks: int | None + The number of clicks on the assign cluster button. + clear_clicks: int | None + The number of clicks on the clear clusters button. + save_name_clicks: int | None + The number of clicks on the save cluster name button. + cancel_name_clicks: int | None + The number of clicks on the cancel cluster name button. + edit_name_clicks: list[int] + The indices of the edit cluster name buttons. + selected_data: dict[str, Any] | None + The selected data from the scatter plot. + current_figure: dict[str, Any] + The current figure. + color_mode: str + The color mode. + show_arrows: list[str] | None + The show arrows. + x_axis: str + The x axis. + y_axis: str + The y axis. + cluster_name: str | None + The cluster name. + + Returns + ------- + tuple[dict[str, str], html.Div | None, str, dict[str, Any] | Any, dict[str, str], str, dict[str, Any] | None]: + The updated clusters tab and handle modal. + """ ctx = dash.callback_context if not ctx.triggered: return ( @@ -962,8 +1039,18 @@ def update_clusters_tab( [dd.Input("save-clusters-csv", "n_clicks")], prevent_initial_call=True, ) - def save_clusters_csv(n_clicks): - """Callback to save clusters to CSV file""" + def save_clusters_csv(n_clicks: int | None) -> html.Div: + """Callback to save clusters to CSV file + + Parameters + ---------- + n_clicks: int | None + The number of clicks on the save clusters CSV button. + + Returns + ------- + html.Div: The cluster container. 
+ """ if n_clicks and self.clusters: try: output_path = self.save_clusters_to_csv() @@ -1035,8 +1122,33 @@ def save_clusters_csv(n_clicks): ], prevent_initial_call=True, ) - def clear_selection(n_clicks, color_mode, show_arrows, x_axis, y_axis): - """Callback to clear the selection and restore original opacity""" + def clear_selection( + n_clicks: int | None, + color_mode: str, + show_arrows: list[str] | None, + x_axis: str, + y_axis: str, + ) -> tuple[dict[str, Any] | Any, dict[str, Any] | None]: + """Callback to clear the selection and restore original opacity + + Parameters + ---------- + n_clicks: int | None + The number of clicks on the clear selection button. + color_mode: str + The color mode. + show_arrows: list[str] | None + The show arrows. + x_axis: str + The x axis. + y_axis: str + The y axis. + + Returns + ------- + tuple[dict[str, Any] | Any, dict[str, Any] | None]: + The new figure and clear selectedData. + """ if n_clicks: # Create a new figure with no selections if color_mode == "track": @@ -1063,7 +1175,9 @@ def clear_selection(n_clicks, color_mode, show_arrows, x_axis, y_axis): return fig, None # Return new figure and clear selectedData return dash.no_update, dash.no_update - def _calculate_equal_aspect_ranges(self, x_data, y_data): + def _calculate_equal_aspect_ranges( + self, x_data: NDArray, y_data: NDArray + ) -> tuple[tuple[float, float], tuple[float, float]]: """Calculate ranges for x and y axes to ensure equal aspect ratio. Parameters @@ -1110,11 +1224,26 @@ def _calculate_equal_aspect_ranges(self, x_data, y_data): def _create_track_colored_figure( self, - show_arrows=False, - x_axis=None, - y_axis=None, - ): - """Create scatter plot with track-based coloring""" + show_arrows: bool = False, + x_axis: str | None = None, + y_axis: str | None = None, + ) -> go.Figure: + """Create scatter plot with track-based coloring + + Parameters + ---------- + show_arrows: bool + The show arrows. + x_axis: str | None + The x axis. + y_axis: str | None + The y axis. + + Returns + ------- + go.Figure + The scatter plot. + """ x_axis = x_axis or self.default_x y_axis = y_axis or self.default_y @@ -1329,10 +1458,10 @@ def _create_track_colored_figure( def _create_time_colored_figure( self, - show_arrows=False, - x_axis=None, - y_axis=None, - ): + show_arrows: bool = False, + x_axis: str | None = None, + y_axis: str | None = None, + ) -> go.Figure: """Create scatter plot with time-based coloring""" x_axis = x_axis or self.default_x y_axis = y_axis or self.default_y @@ -1481,7 +1610,7 @@ def _create_time_colored_figure( return fig @staticmethod - def _normalize_image(img_array): + def _normalize_image(img_array: NDArray) -> NDArray: """Normalize a single image array to [0, 255] more efficiently""" min_val = img_array.min() max_val = img_array.max() @@ -1491,7 +1620,7 @@ def _normalize_image(img_array): return ((img_array - min_val) * 255 / (max_val - min_val)).astype(np.uint8) @staticmethod - def _numpy_to_base64(img_array): + def _numpy_to_base64(img_array: NDArray) -> str: """Convert numpy array to base64 string with compression""" if not isinstance(img_array, np.uint8): img_array = img_array.astype(np.uint8) @@ -1503,12 +1632,12 @@ def _numpy_to_base64(img_array): "utf-8" ) - def save_cache(self, cache_path: str | None = None): + def save_cache(self, cache_path: str |Path | None = None) -> None: """Save the image cache to disk using pickle. Parameters ---------- - cache_path : str | None, optional + cache_path : str | Path | None, optional Path to save the cache. 
If None, uses self.cache_path, by default None """ import pickle @@ -1543,12 +1672,12 @@ def save_cache(self, cache_path: str | None = None): except Exception as e: logger.error(f"Error saving cache: {e}") - def load_cache(self, cache_path: str | None = None) -> bool: + def load_cache(self, cache_path: str | Path | None = None) -> bool: """Load the image cache from disk using pickle. Parameters ---------- - cache_path : str | None, optional + cache_path : str | Path | None, optional Path to load the cache from. If None, uses self.cache_path, by default None Returns @@ -1596,7 +1725,7 @@ def load_cache(self, cache_path: str | None = None) -> bool: logger.error(f"Error loading cache: {e}") return False - def preload_images(self): + def preload_images(self) -> None: """Preload all images into memory""" # Try to load from cache first if self.cache_path and self.load_cache(): @@ -1625,7 +1754,7 @@ def preload_images(self): final_yx_patch_size=self.yx_patch_size, batch_size=1, num_workers=self.num_loading_workers, - normalizations=None, + normalizations=[], predict_cells=True, ) data_module.setup("predict") @@ -1696,12 +1825,14 @@ def preload_images(self): if self.cache_path: self.save_cache() - def _cleanup_cache(self): + def _cleanup_cache(self) -> None: """Clear the image cache when the program exits""" logging.info("Cleaning up image cache...") self.image_cache.clear() - def _get_trajectory_images_lasso(self, x_axis, y_axis, selected_data): + def _get_trajectory_images_lasso( + self, x_axis: str, y_axis: str, selected_data: dict[str, Any] | None + ) -> html.Div: """Get images of points selected by lasso""" if not selected_data or not selected_data.get("points"): return html.Div("Use the lasso tool to select points") @@ -1908,7 +2039,7 @@ def _get_output_info_display(self) -> html.Div: }, ) - def _get_cluster_images(self): + def _get_cluster_images(self) -> html.Div: """Display images for all clusters in a grid layout""" if not self.clusters: return html.Div( @@ -2117,7 +2248,7 @@ def get_output_dir(self) -> Path: """ return self.output_dir - def save_clusters_to_csv(self, output_path: str | None = None) -> str: + def save_clusters_to_csv(self, output_path: str | Path | None = None) -> str: """ Save cluster information to CSV file. @@ -2126,7 +2257,7 @@ def save_clusters_to_csv(self, output_path: str | None = None) -> str: Parameters ---------- - output_path : str | None, optional + output_path : str | Path | None, optional Path to save the CSV file. 
If None, generates a timestamped filename in the output directory, by default None @@ -2195,7 +2326,7 @@ def save_clusters_to_csv(self, output_path: str | None = None) -> str: logger.error(f"Error saving clusters to CSV: {e}") raise - def run(self, debug=False, port=None): + def run(self, debug: bool = False, port: int | None = None) -> None: """Run the Dash server Parameters @@ -2207,12 +2338,12 @@ def run(self, debug=False, port=None): """ import socket - def is_port_in_use(port): + def is_port_in_use(port: int) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: s.bind(("127.0.0.1", port)) return False - except socket.error: + except OSError: return True if port is None: diff --git a/viscy/representation/multi_modal.py b/viscy/representation/multi_modal.py index 55481d434..c0f1cfe94 100644 --- a/viscy/representation/multi_modal.py +++ b/viscy/representation/multi_modal.py @@ -1,10 +1,10 @@ +from collections.abc import Sequence from logging import getLogger -from typing import Literal, Sequence +from typing import Literal import torch from pytorch_metric_learning.losses import NTXentLoss from torch import Tensor, nn - from viscy.data.typing import TripletSample from viscy.representation.contrastive import ContrastiveEncoder from viscy.representation.engine import ContrastiveModule @@ -13,6 +13,20 @@ class JointEncoders(nn.Module): + """Joint multi-modal encoders for cross-modal representation learning. + + Pairs source and target encoders for CLIP-style contrastive learning + across different modalities or channels. Enables cross-modal alignment + and similarity computation through joint feature extraction. + + Parameters + ---------- + source_encoder : nn.Module | ContrastiveEncoder + Encoder for source modality/channel data. + target_encoder : nn.Module | ContrastiveEncoder + Encoder for target modality/channel data. + """ + def __init__( self, source_encoder: nn.Module | ContrastiveEncoder, @@ -25,14 +39,59 @@ def __init__( def forward( self, source: Tensor, target: Tensor ) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: + """Forward pass through both encoders for multi-modal features. + + Parameters + ---------- + source : Tensor + Source modality input tensor. + target : Tensor + Target modality input tensor. + + Returns + ------- + tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]] + Tuple of (source_features, source_projections) and + (target_features, target_projections) for cross-modal learning. + """ return self.source_encoder(source), self.target_encoder(target) def forward_features(self, source: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: + """Extract feature representations from both modalities. + + Parameters + ---------- + source : Tensor + Source modality input tensor. + target : Tensor + Target modality input tensor. + + Returns + ------- + tuple[Tensor, Tensor] + Feature representations from source and target encoders for + multi-modal representation learning. + """ return self.source_encoder(source)[0], self.target_encoder(target)[0] def forward_projections( self, source: Tensor, target: Tensor ) -> tuple[Tensor, Tensor]: + """Extract projection representations for contrastive learning. + + Parameters + ---------- + source : Tensor + Source modality input tensor. + target : Tensor + Target modality input tensor. + + Returns + ------- + tuple[Tensor, Tensor] + Projection representations from source and target encoders for + cross-modal contrastive alignment and similarity computation. 
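+
+ Examples
+ --------
+ A minimal sketch; the encoders and the input tensors are
+ placeholders for illustration:
+
+ >>> joint = JointEncoders(source_encoder, target_encoder)  # doctest: +SKIP
+ >>> z_src, z_tgt = joint.forward_projections(source, target)  # doctest: +SKIP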
+ """ return self.source_encoder(source)[1], self.target_encoder(target)[1] @@ -67,6 +126,21 @@ def __init__( self._prediction_arm = prediction_arm def forward(self, source: Tensor, target: Tensor) -> tuple[Tensor, Tensor]: + """Forward pass for cross-modal contrastive projections. + + Parameters + ---------- + source : Tensor + Source modality input tensor. + target : Tensor + Target modality input tensor. + + Returns + ------- + tuple[Tensor, Tensor] + Projection tensors from source and target encoders for + cross-modal contrastive learning and alignment. + """ return self.model.forward_projections(source, target) def _info_nce_style_loss(self, z1: Tensor, z2: Tensor) -> Tensor: @@ -110,12 +184,53 @@ def _fit_forward_step( return loss def training_step(self, batch: TripletSample, batch_idx: int) -> Tensor: + """Training step for cross-modal contrastive learning. + + Parameters + ---------- + batch : TripletSample + Batch containing anchor and positive samples for multi-modal + contrastive learning. + batch_idx : int + Batch index in current epoch. + + Returns + ------- + Tensor + Cross-modal contrastive loss for training optimization. + """ return self._fit_forward_step(batch=batch, batch_idx=batch_idx, stage="train") def validation_step(self, batch: TripletSample, batch_idx: int) -> Tensor: + """Validation step for cross-modal contrastive learning. + + Parameters + ---------- + batch : TripletSample + Batch containing anchor and positive samples for multi-modal + validation. + batch_idx : int + Batch index in current validation epoch. + + Returns + ------- + Tensor + Cross-modal contrastive loss for validation monitoring. + """ return self._fit_forward_step(batch=batch, batch_idx=batch_idx, stage="val") def on_predict_start(self) -> None: + """Configure prediction encoder arm for multi-modal inference. + + Sets up the appropriate encoder (source or target) and channel slice + based on the prediction_arm configuration for single-modality + inference from the trained cross-modal model. + + Raises + ------ + ValueError + If prediction_arm is not 'source' or 'target'. + """ _logger.info(f"Using {self._prediction_arm} encoder for predictions.") if self._prediction_arm == "source": self._prediction_encoder = self.model.source_encoder @@ -129,6 +244,27 @@ def on_predict_start(self) -> None: def predict_step( self, batch: TripletSample, batch_idx: int, dataloader_idx: int = 0 ): + """Prediction step using selected encoder arm. + + Extracts features and projections using the configured prediction + encoder (source or target) for single-modality inference from the + trained cross-modal model. + + Parameters + ---------- + batch : TripletSample + Batch containing anchor samples for prediction. + batch_idx : int + Batch index in current prediction run. + dataloader_idx : int, default=0 + Index of dataloader when using multiple prediction dataloaders. + + Returns + ------- + dict + Dictionary containing 'features', 'projections', and 'index' + for the predicted samples from the selected modality encoder. + """ features, projections = self._prediction_encoder( batch["anchor"][:, self._prediction_channel_slice] ) diff --git a/viscy/trainer.py b/viscy/trainer.py index 03395a371..5f12db396 100644 --- a/viscy/trainer.py +++ b/viscy/trainer.py @@ -15,6 +15,12 @@ class VisCyTrainer(Trainer): + """Extended Lightning Trainer for VisCy with preprocessing and export capabilities. 
+ + Provides additional functionality for dataset preprocessing, model export, + and normalization metadata computation for computer vision training workflows. + """ + def preprocess( self, data_path: Path, @@ -118,6 +124,29 @@ def precompute( exclude_fovs: list[str] | None = None, model: LightningModule | None = None, ): + """Precompute and normalize image arrays for efficient training. + + Parameters + ---------- + data_path : Path + Path to input HCS OME-Zarr dataset + output_path : Path + Path to save precomputed arrays + channel_names : list[str] + List of channel names to process + subtrahends : list[Literal["mean"] | float] + Subtraction values for normalization (per channel) + divisors : list[Literal["std"] | tuple[float, float]] + Division values for normalization (per channel) + image_array_key : str, optional + Array key in OME-Zarr structure, by default "0" + include_wells : list[str] | None, optional + Wells to include, by default None + exclude_fovs : list[str] | None, optional + Fields of view to exclude, by default None + model : LightningModule | None, optional + Ignored placeholder parameter, by default None + """ precompute_array( data_path=data_path, output_path=output_path, diff --git a/viscy/transforms/_gaussian_blur.py b/viscy/transforms/_gaussian_blur.py index 85292e408..1f32d9a99 100644 --- a/viscy/transforms/_gaussian_blur.py +++ b/viscy/transforms/_gaussian_blur.py @@ -1,6 +1,7 @@ """3D version of `kornia.augmentation._2d.intensity.gaussian_blur`.""" -from typing import Any, Iterable +from collections.abc import Iterable +from typing import Any from kornia.augmentation import random_generator as rg from kornia.augmentation._3d.intensity.base import IntensityAugmentationBase3D diff --git a/viscy/transforms/_redef.py b/viscy/transforms/_redef.py index 696c81abc..9171ffd03 100644 --- a/viscy/transforms/_redef.py +++ b/viscy/transforms/_redef.py @@ -1,6 +1,7 @@ """Redefine transforms from MONAI for jsonargparse.""" -from typing import Sequence +from collections.abc import Sequence +from typing import Any from monai.transforms import ( CenterSpatialCropd, @@ -26,8 +27,8 @@ def __init__( detach: bool = True, pad_batch: bool = True, fill_value: float | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__( keys=keys, detach=detach, @@ -38,7 +39,7 @@ def __init__( class ToDeviced(ToDeviced): - def __init__(self, keys: Sequence[str] | str, **kwargs): + def __init__(self, keys: Sequence[str] | str, **kwargs: Any) -> None: super().__init__(keys=keys, **kwargs) @@ -49,8 +50,8 @@ def __init__( w_key: str, spatial_size: Sequence[int], num_samples: int = 1, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__( keys=keys, w_key=w_key, @@ -68,8 +69,8 @@ def __init__( rotate_range: Sequence[float | Sequence[float]] | float, shear_range: Sequence[float | Sequence[float]] | float, scale_range: Sequence[float | Sequence[float]] | float, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__( keys=keys, prob=prob, @@ -86,8 +87,8 @@ def __init__( keys: Sequence[str] | str, prob: float, gamma: tuple[float, float] | float, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(keys=keys, prob=prob, gamma=gamma, **kwargs) @@ -97,8 +98,8 @@ def __init__( keys: Sequence[str] | str, factors: tuple[float, float] | float, prob: float, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(keys=keys, factors=factors, prob=prob, **kwargs) @@ -109,8 +110,8 @@ def __init__( prob: float, mean: float, std: float, - **kwargs, - ): + 
**kwargs: Any, + ) -> None: super().__init__(keys=keys, prob=prob, mean=mean, std=std, **kwargs) @@ -122,8 +123,8 @@ def __init__( sigma_x: tuple[float, float] | float, sigma_y: tuple[float, float] | float, sigma_z: tuple[float, float] | float, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__( keys=keys, prob=prob, @@ -147,7 +148,7 @@ def __init__( channel_wise: bool = False, dtype: DTypeLike | None = None, allow_missing_keys: bool = False, - ): + ) -> None: super().__init__( keys=keys, lower=lower, @@ -168,8 +169,8 @@ def __init__( keys: Sequence[str] | str, roi_size: Sequence[int] | int, random_center: bool = True, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__( keys=keys, roi_size=roi_size, @@ -183,8 +184,8 @@ def __init__( self, keys: Sequence[str] | str, roi_size: Sequence[int] | int, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(keys=keys, roi_size=roi_size, **kwargs) @@ -194,6 +195,6 @@ def __init__( keys: Sequence[str] | str, prob: float, spatial_axis: Sequence[int] | int, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(keys=keys, prob=prob, spatial_axis=spatial_axis, **kwargs) diff --git a/viscy/transforms/_transforms.py b/viscy/transforms/_transforms.py index c9418f1b6..d6b68e6b0 100644 --- a/viscy/transforms/_transforms.py +++ b/viscy/transforms/_transforms.py @@ -1,3 +1,5 @@ +from collections.abc import Iterable, Sequence +from typing import Literal from warnings import warn import numpy as np @@ -14,7 +16,6 @@ ) from numpy.typing import DTypeLike from torch import Tensor -from typing_extensions import Iterable, Literal, Sequence from viscy.data.typing import ChannelMap, Sample @@ -75,9 +76,7 @@ def _normalize(): class RandInvertIntensityd(MapTransform, RandomizableTransform): - """ - Randomly invert the intensity of the image. - """ + """Randomly invert the intensity of the image.""" def __init__( self, @@ -99,8 +98,8 @@ def __call__(self, sample: Sample) -> Sample: class TiledSpatialCropSamplesd(MapTransform, MultiSampleTrait): - """ - Crop multiple tiled ROIs from an image. + """Crop multiple tiled ROIs from an image. + Used for deterministic cropping in validation. """ @@ -166,7 +165,7 @@ def __call__(self, sample: Sample) -> Sample: class BatchedZoom(Transform): - "Batched zoom transform using ``torch.nn.functional.interpolate``." + """Batched zoom transform using ``torch.nn.functional.interpolate``.""" def __init__( self, diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 56af9b985..22e6f3e40 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -1,7 +1,8 @@ import logging import os import random -from typing import Callable, Literal, Sequence, Union +from collections.abc import Callable, Sequence +from typing import Any, Literal, Union import numpy as np import torch @@ -24,7 +25,6 @@ structural_similarity_index_measure, ) from torchmetrics.functional.segmentation import dice_score - from viscy.data.combined import CombinedDataModule from viscy.data.gpu_aug import GPUTransformDataModule from viscy.data.typing import Sample @@ -48,17 +48,23 @@ class MixedLoss(nn.Module): """Mixed reconstruction loss. + Adapted from Zhao et al, https://arxiv.org/pdf/1511.08861.pdf Reduces to simple distances if only one weight is non-zero. 
- :param float l1_alpha: L1 loss weight, defaults to 0.5 - :param float l2_alpha: L2 loss weight, defaults to 0.0 - :param float ms_dssim_alpha: MS-DSSIM weight, defaults to 0.5 + Parameters + ---------- + l1_alpha : float, optional + L1 loss weight, by default 0.5 + l2_alpha : float, optional + L2 loss weight, by default 0.0 + ms_dssim_alpha : float, optional + MS-DSSIM weight, by default 0.5 """ def __init__( self, l1_alpha: float = 0.5, l2_alpha: float = 0.0, ms_dssim_alpha: float = 0.5 - ): + ) -> None: super().__init__() if not any([l1_alpha, l2_alpha, ms_dssim_alpha]): raise ValueError("Loss term weights cannot be all zero!") @@ -67,7 +73,21 @@ def __init__( self.ms_dssim_alpha = ms_dssim_alpha @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) - def forward(self, preds, target): + def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Compute mixed reconstruction loss. + + Parameters + ---------- + preds : torch.Tensor + Predicted tensor + target : torch.Tensor + Target tensor + + Returns + ------- + torch.Tensor + Combined loss value + """ loss = 0 if self.l1_alpha: # the gaussian in the reference is not used @@ -84,7 +104,30 @@ def forward(self, preds, target): class MaskedMSELoss(nn.Module): - def forward(self, preds, original, mask): + """Masked mean squared error loss. + + Computes MSE loss only for masked regions. + """ + + def forward( + self, preds: torch.Tensor, original: torch.Tensor, mask: torch.Tensor + ) -> torch.Tensor: + """Compute masked MSE loss. + + Parameters + ---------- + preds : torch.Tensor + Predicted tensor. + original : torch.Tensor + Original tensor. + mask : torch.Tensor + Binary mask tensor. + + Returns + ------- + torch.Tensor + Masked MSE loss value. + """ loss = F.mse_loss(preds, original, reduction="none") loss = (loss.mean(2) * mask).sum() / mask.sum() return loss @@ -93,44 +136,52 @@ def forward(self, preds, original, mask): class VSUNet(LightningModule): """Regression U-Net module for virtual staining. 
- :param dict model_config: model config, - defaults to :py:class:`viscy.unet.utils.model.ModelDefaults25D` - :param Union[nn.Module, MixedLoss] loss_function: - loss function in training/validation, - if a dictionary, should specify weights of each term - ('l1_alpha', 'l2_alpha', 'ssim_alpha') - defaults to L2 (mean squared error) - :param float lr: learning rate in training, defaults to 1e-3 - :param Literal['WarmupCosine', 'Constant'] schedule: - learning rate scheduler, defaults to "Constant" - :param str ckpt_path: path to the checkpoint to load weights, defaults to None - :param int log_batches_per_epoch: - number of batches to log each training/validation epoch, - has to be smaller than steps per epoch, defaults to 8 - :param int log_samples_per_batch: - number of samples to log each training/validation batch, - has to be smaller than batch size, defaults to 1 - :param Sequence[int] example_input_yx_shape: - XY shape of the example input for network graph tracing, defaults to (256, 256) - :param str test_cellpose_model_path: - path to the CellPose model for testing segmentation, defaults to None - :param float test_cellpose_diameter: - diameter parameter of the CellPose model for testing segmentation, - defaults to None - :param bool test_evaluate_cellpose: - evaluate the performance of the CellPose model instead of the trained model - in test stage, defaults to False - :param bool test_time_augmentations: - apply test time augmentations in test stage, defaults to False - :param Literal['mean', 'median', 'product'] tta_type: - type of test time augmentations aggregation, defaults to "mean" + Parameters + ---------- + architecture : Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae", "UNeXt2_2D"] + Model architecture type. + model_config : dict, optional + Model config, defaults to :py:class:`viscy.unet.utils.model.ModelDefaults25D`, + by default {}. + loss_function : Union[nn.Module, MixedLoss], optional + Loss function in training/validation. If a dictionary, should specify weights + of each term ('l1_alpha', 'l2_alpha', 'ssim_alpha'), defaults to L2 + (mean squared error), by default None. + lr : float, optional + Learning rate in training, by default 1e-3. + schedule : Literal['WarmupCosine', 'Constant'], optional + Learning rate scheduler, by default "Constant". + freeze_encoder : bool, optional + Whether to freeze encoder weights, by default False. + ckpt_path : str, optional + Path to the checkpoint to load weights, by default None. + log_batches_per_epoch : int, optional + Number of batches to log each training/validation epoch, + has to be smaller than steps per epoch, by default 8. + log_samples_per_batch : int, optional + Number of samples to log each training/validation batch, + has to be smaller than batch size, by default 1. + example_input_yx_shape : Sequence[int], optional + XY shape of the example input for network graph tracing, by default (256, 256). + test_cellpose_model_path : str, optional + Path to the CellPose model for testing segmentation, by default None. + test_cellpose_diameter : float, optional + Diameter parameter of the CellPose model for testing segmentation, + by default None. + test_evaluate_cellpose : bool, optional + Evaluate the performance of the CellPose model instead of the trained model + in test stage, by default False. + test_time_augmentations : bool, optional + Apply test time augmentations in test stage, by default False. 
+ tta_type : Literal['mean', 'median', 'product'], optional
+ Type of test time augmentations aggregation, by default "mean".
 """

 def __init__(
 self,
 architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae", "UNeXt2_2D"],
 model_config: dict = {},
- loss_function: Union[nn.Module, MixedLoss] | None = None,
+ loss_function: nn.Module | MixedLoss | None = None,
 lr: float = 1e-3,
 schedule: Literal["WarmupCosine", "Constant"] = "Constant",
 freeze_encoder: bool = False,
@@ -185,9 +236,37 @@ def __init__(
 ) # loading only weights

 def forward(self, x: Tensor) -> Tensor:
+ """Forward pass through the model.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor.
+
+ Returns
+ -------
+ torch.Tensor
+ Model output.
+ """
 return self.model(x)

- def training_step(self, batch: Sample | Sequence[Sample], batch_idx: int):
+ def training_step(
+ self, batch: Sample | Sequence[Sample], batch_idx: int
+ ) -> torch.Tensor:
+ """Execute single training step.
+
+ Parameters
+ ----------
+ batch : Sample or Sequence[Sample]
+ Training batch data.
+ batch_idx : int
+ Batch index.
+
+ Returns
+ -------
+ torch.Tensor
+ Training loss.
+ """
 losses = []
 batch_size = 0
 if not isinstance(batch, Sequence):
@@ -216,7 +295,20 @@ def training_step(self, batch: Sample | Sequence[Sample], batch_idx: int):
 )
 return loss_step

- def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
+ def validation_step(
+ self, batch: Sample, batch_idx: int, dataloader_idx: int = 0
+ ) -> None:
+ """Execute single validation step.
+
+ Parameters
+ ----------
+ batch : Sample
+ Validation batch data.
+ batch_idx : int
+ Batch index.
+ dataloader_idx : int, default=0
+ Dataloader index for multi-dataloader validation.
+ """
 source: Tensor = batch["source"]
 target: Tensor = batch["target"]
 pred = self.forward(source)
@@ -235,7 +327,16 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0
 detach_sample((source, target, pred), self.log_samples_per_batch)
 )

- def test_step(self, batch: Sample, batch_idx: int):
+ def test_step(self, batch: Sample, batch_idx: int) -> None:
+ """Execute single test step.
+
+ Parameters
+ ----------
+ batch : Sample
+ Test batch data.
+ batch_idx : int
+ Batch index.
+ """
 source = batch["source"]
 target = batch["target"]
 center_index = target.shape[-3] // 2
@@ -266,7 +367,7 @@ def test_step(self, batch: Sample, batch_idx: int):
 else:
 self._log_segmentation_metrics(None, None)

- def _log_regression_metrics(self, pred: Tensor, target: Tensor):
+ def _log_regression_metrics(self, pred: Tensor, target: Tensor) -> None:
 # paired image translation metrics
 self.log_dict(
 {
@@ -298,7 +399,7 @@ def _cellpose_predict(self, pred: Tensor, name: str) -> torch.ShortTensor:

 def _log_segmentation_metrics(
 self, pred_labels: torch.ShortTensor, target_labels: torch.ShortTensor
- ):
+ ) -> None:
 compute = pred_labels is not None
 if compute:
 pred_binary = pred_labels > 0
@@ -337,7 +438,25 @@ def _log_segmentation_metrics(
 on_epoch=False,
 )

- def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
+ def predict_step(
+ self, batch: Sample, batch_idx: int, dataloader_idx: int = 0
+ ) -> Tensor:
+ """Execute single prediction step.
+
+ Parameters
+ ----------
+ batch : Sample
+ Prediction batch data.
+ batch_idx : int
+ Batch index.
+ dataloader_idx : int, default=0
+ Dataloader index.
+
+ Returns
+ -------
+ torch.Tensor
+ Model prediction.
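+
+ Notes
+ -----
+ When ``test_time_augmentations`` is enabled, the prediction is
+ aggregated over rotated inputs according to ``tta_type``; otherwise
+ the input is padded, predicted on, and cropped back to its original
+ shape.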
+ """ source = batch["source"] if self.test_time_augmentations: prediction = self.perform_test_time_augmentations(source) @@ -349,13 +468,21 @@ def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): return prediction def perform_test_time_augmentations(self, source: Tensor) -> Tensor: - """Perform test time augmentations on the input source - and aggregate the predictions using the specified method. + """Perform test time augmentations and aggregate predictions. - :param source: input tensor - :return: aggregated prediction - """ + Apply rotational augmentations to input source and aggregate the + predictions using the specified method. + + Parameters + ---------- + source : torch.Tensor + Input tensor. + Returns + ------- + torch.Tensor + Aggregated prediction. + """ # Save the yx coords to crop post rotations self._original_shape_yx = source.shape[-2:] predictions = [] @@ -384,11 +511,13 @@ def perform_test_time_augmentations(self, source: Tensor) -> Tensor: prediction = torch.exp(log_prediction_sum) return prediction - def on_train_epoch_end(self): + def on_train_epoch_end(self) -> None: + """Log training samples at end of epoch.""" self._log_samples("train_samples", self.training_step_outputs) self.training_step_outputs = [] - def on_validation_epoch_end(self): + def on_validation_epoch_end(self) -> None: + """Log validation samples and compute average loss at end of epoch.""" super().on_validation_epoch_end() self._log_samples("val_samples", self.validation_step_outputs) # average within each dataloader @@ -401,7 +530,7 @@ def on_validation_epoch_end(self): self.validation_step_outputs.clear() self.validation_losses.clear() - def on_test_start(self): + def on_test_start(self) -> None: """Load CellPose model for segmentation.""" if self.test_cellpose_model_path is not None: try: @@ -417,14 +546,23 @@ def on_test_start(self): '`pip install viscy"[metrics]"`' ) - def on_predict_start(self): - """Pad the input shape to be divisible by the downsampling factor. + def on_predict_start(self) -> None: + """Setup prediction padding transform. + + Pad the input shape to be divisible by the downsampling factor. The inverse of this transform crops the prediction to original shape. """ down_factor = 2**self.model.num_blocks self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - def configure_optimizers(self): + def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[Any]]: + """Configure optimizer and learning rate scheduler. + + Returns + ------- + tuple + Tuple containing optimizer and scheduler lists. + """ if self.freeze_encoder: self.model: FullyConvolutionalMAE self.model.encoder.requires_grad_(False) @@ -442,7 +580,7 @@ def configure_optimizers(self): ) return [optimizer], [scheduler] - def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]) -> None: grid = render_images(imgs) self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" @@ -477,8 +615,7 @@ def _crop_to_original(self, tensor: Tensor) -> Tensor: class AugmentedPredictionVSUNet(LightningModule): - """Apply arbitrary collection of test-time augmentations - for image translation prediction. + """Apply arbitrary collection of test-time augmentations for image translation prediction. Parameters ---------- @@ -528,15 +665,51 @@ def __init__( self._reduction = reduction def forward(self, x: Tensor) -> Tensor: + """Forward pass through the model. 
+ + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Model output. + """ return self.model(x) def setup(self, stage: str) -> None: + """Setup method for Lightning module. + + Parameters + ---------- + stage : str + Stage name (only 'predict' is supported). + + Raises + ------ + NotImplementedError + If stage is not 'predict'. + """ if stage != "predict": raise NotImplementedError( f"Only the 'predict' stage is supported by {type(self)}" ) def _reduce_predictions(self, preds: list[Tensor]) -> Tensor: + """Reduce multiple predictions using specified method. + + Parameters + ---------- + preds : list[torch.Tensor] + List of prediction tensors. + + Returns + ------- + torch.Tensor + Reduced prediction tensor. + """ prediction = torch.stack(preds, dim=0) if self._reduction == "mean": prediction = prediction.mean(dim=0) @@ -547,6 +720,22 @@ def _reduce_predictions(self, preds: list[Tensor]) -> Tensor: def predict_step( self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: + """Execute single prediction step with augmentations. + + Parameters + ---------- + batch : Sample + Prediction batch data. + batch_idx : int + Batch index. + dataloader_idx : int, default=0 + Dataloader index. + + Returns + ------- + torch.Tensor + Aggregated prediction from augmented inputs. + """ source = batch["source"] preds = [] for forward_t, inverse_t in zip( @@ -566,16 +755,36 @@ def predict_step( class FcmaeUNet(VSUNet): + """Fully Convolutional Masked Autoencoder U-Net. + + Extends VSUNet to support masked autoencoder pre-training and supervised + fine-tuning for virtual staining tasks. + + Parameters + ---------- + fit_mask_ratio : float, default=0.0 + Masking ratio for FCMAE pre-training. + **kwargs + Additional arguments passed to VSUNet. + """ + def __init__( self, fit_mask_ratio: float = 0.0, **kwargs, - ): + ) -> None: super().__init__(architecture="fcmae", **kwargs) self.fit_mask_ratio = fit_mask_ratio self.save_hyperparameters(ignore=["loss_function"]) - def on_fit_start(self): + def on_fit_start(self) -> None: + """Setup data modules and validate configuration for training. + + Raises + ------ + ValueError + If data module configuration is incompatible with FCMAE training. + """ dm = self.trainer.datamodule if not isinstance(dm, CombinedDataModule): raise ValueError( @@ -595,12 +804,42 @@ def on_fit_start(self): f"got {type(self.loss_function)}" ) - def forward(self, x: Tensor, mask_ratio: float = 0.0): + def forward( + self, x: Tensor, mask_ratio: float = 0.0 + ) -> tuple[Tensor, Tensor] | Tensor: + """Forward pass with optional masking. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + mask_ratio : float, default=0.0 + Masking ratio for FCMAE mode. + + Returns + ------- + torch.Tensor or tuple + Model output, optionally with mask if mask_ratio > 0. + """ return self.model(x, mask_ratio) def forward_fit_fcmae( self, batch: Sample, return_target: bool = False ) -> tuple[Tensor, Tensor | None, Tensor]: + """Forward pass for FCMAE pre-training. + + Parameters + ---------- + batch : Sample + Input batch. + return_target : bool, default=False + Whether to return masked target for logging. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor or None, torch.Tensor] + Prediction, target (if requested), and loss. 
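+
+ Notes
+ -----
+ The reconstruction loss is computed between the prediction and the
+ unmasked input, weighted by the returned mask, so only masked
+ regions contribute (see :py:class:`MaskedMSELoss`).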
+ """ x = batch["source"] pred, mask = self.forward(x, mask_ratio=self.fit_mask_ratio) loss = self.loss_function(pred, x, mask) @@ -611,6 +850,18 @@ def forward_fit_fcmae( return pred, target, loss def forward_fit_supervised(self, batch: Sample) -> tuple[Tensor, Tensor, Tensor]: + """Forward pass for supervised training. + + Parameters + ---------- + batch : Sample + Input batch containing source and target. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + Prediction, target, and loss. + """ x = batch["source"] target = batch["target"] pred = self.forward(x) @@ -620,6 +871,23 @@ def forward_fit_supervised(self, batch: Sample) -> tuple[Tensor, Tensor, Tensor] def forward_fit_task( self, batch: Sample, batch_idx: int ) -> tuple[Tensor, Tensor | None, Tensor]: + """Forward pass for current training task. + + Automatically selects FCMAE pre-training or supervised training + based on model configuration. + + Parameters + ---------- + batch : Sample + Input batch. + batch_idx : int + Batch index. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor or None, torch.Tensor] + Prediction, target, and loss. + """ if self.model.pretraining: if batch_idx < self.log_batches_per_epoch: return_target = True @@ -630,6 +898,18 @@ def forward_fit_task( @torch.no_grad() def train_transform_and_collate(self, batch: list[dict[str, Tensor]]) -> Sample: + """Apply training transforms and collate batch data. + + Parameters + ---------- + batch : list[dict[str, torch.Tensor]] + List of batch dictionaries from multiple data modules. + + Returns + ------- + Sample + Collated and transformed sample. + """ transformed = [] for dataset_batch, dm in zip(batch, self.datamodules): dataset_batch = dm.train_gpu_transforms(dataset_batch) @@ -642,10 +922,38 @@ def train_transform_and_collate(self, batch: list[dict[str, Tensor]]) -> Sample: def val_transform_and_collate( self, batch: list[Sample], dataloader_idx: int ) -> Tensor: + """Apply validation transforms and collate batch data. + + Parameters + ---------- + batch : list[Sample] + List of samples. + dataloader_idx : int + Index of the validation dataloader. + + Returns + ------- + torch.Tensor + Collated and transformed batch. + """ batch = self.datamodules[dataloader_idx].val_gpu_transforms(batch) return collate_meta_tensor(batch) def training_step(self, batch: list[list[Sample]], batch_idx: int) -> Tensor: + """Execute single training step for FCMAE. + + Parameters + ---------- + batch : list[list[Sample]] + Nested list of samples from multiple data modules. + batch_idx : int + Batch index. + + Returns + ------- + torch.Tensor + Training loss. + """ batch = self.train_transform_and_collate(batch) pred, target, loss = self.forward_fit_task(batch, batch_idx) if batch_idx < self.log_batches_per_epoch: @@ -669,6 +977,17 @@ def training_step(self, batch: list[list[Sample]], batch_idx: int) -> Tensor: def validation_step( self, batch: list[Sample], batch_idx: int, dataloader_idx: int = 0 ) -> None: + """Execute single validation step for FCMAE. + + Parameters + ---------- + batch : list[Sample] + List of validation samples. + batch_idx : int + Batch index. + dataloader_idx : int, default=0 + Dataloader index. 
+ """ batch = self.val_transform_and_collate(batch, dataloader_idx) pred, target, loss = self.forward_fit_task(batch, batch_idx) if dataloader_idx + 1 > len(self.validation_losses): diff --git a/viscy/translation/evaluation.py b/viscy/translation/evaluation.py index 11812f4fe..78376548f 100644 --- a/viscy/translation/evaluation.py +++ b/viscy/translation/evaluation.py @@ -5,7 +5,6 @@ from lightning.pytorch import LightningModule from torchmetrics.functional import accuracy, jaccard_index from torchmetrics.functional.segmentation import dice_score - from viscy.data.typing import SegmentationSample from viscy.translation.evaluation_metrics import mean_average_precision @@ -20,6 +19,15 @@ def __init__(self, aggregate_epoch: bool = False) -> None: self.aggregate_epoch = aggregate_epoch def test_step(self, batch: SegmentationSample, batch_idx: int) -> None: + """Compute segmentation metrics for a test batch. + + Parameters + ---------- + batch : SegmentationSample + Batch containing prediction and target segmentation masks + batch_idx : int + Batch index + """ pred = batch["pred"] target = batch["target"] if not pred.shape[0] == 1 and target.shape[0] == 1: diff --git a/viscy/translation/evaluation_metrics.py b/viscy/translation/evaluation_metrics.py index bb89858f2..598b52a69 100644 --- a/viscy/translation/evaluation_metrics.py +++ b/viscy/translation/evaluation_metrics.py @@ -1,6 +1,7 @@ -"""Metrics for model evaluation""" +"""Metrics for model evaluation.""" -from typing import Sequence, Union +from collections.abc import Sequence +from typing import Union from warnings import warn import numpy as np @@ -14,11 +15,22 @@ def VOI_metric(target, prediction): - """variation of information metric - Reports overlap between predicted and ground truth mask - : param np.array target: ground truth mask - : param np.array prediction: model infered FL image cellpose mask - : return float VI: VI for image masks + """ + Variation of information metric. + + Reports overlap between predicted and ground truth mask. + + Parameters + ---------- + target : np.array + Ground truth mask. + prediction : np.array + Model inferred FL image cellpose mask. + + Returns + ------- + list of float + VI for image masks. """ # cellpose segmentation of predicted image: outputs labl mask pred_bin = prediction > 0 @@ -56,6 +68,21 @@ def VOI_metric(target, prediction): def POD_metric(target_bin, pred_bin): + """ + Probability of detection metric for object matching. + + Parameters + ---------- + target_bin : array_like + Binary ground truth mask. + pred_bin : array_like + Binary predicted mask. + + Returns + ------- + tuple + POD and various detection statistics. + """ # pred_bin = cpmask_array(prediction) # relabel mask for ordered labelling across images for efficient LAP mapping @@ -100,129 +127,198 @@ def POD_metric(target_bin, pred_bin): matching_targ.append(rid) matching_pred.append(cid) - true_positives = len(matching_pred) - false_positives = n_predObj - len(matching_pred) - false_negatives = n_targObj - len(matching_targ) - precision = true_positives / (true_positives + false_positives) - recall = true_positives / (true_positives + false_negatives) - f1_score = 2 * (precision * recall / (precision + recall)) - - return [ - true_positives, - false_positives, - false_negatives, - precision, - recall, - f1_score, - ] - - -def labels_to_masks(labels: torch.ShortTensor) -> torch.BoolTensor: - """Convert integer labels to a stack of boolean masks. 
 """
 if labels.ndim != 2:
 raise ValueError(f"Labels must be 2D, got shape {labels.shape}.")
 segments = torch.unique(labels)
 n_instances = segments.numel() - 1
 masks = torch.zeros(
 (n_instances, *labels.shape), dtype=torch.bool, device=labels.device
 )
 # TODO: optimize this?
 for s, segment in enumerate(segments):
 # start from label value 1, i.e. skip background label
 masks[s - 1] = labels == segment
 return masks


 def labels_to_detection(labels: torch.ShortTensor) -> dict[str, torch.Tensor]:
 """Convert integer labels to a torchvision/torchmetrics detection dictionary.

- :param torch.ShortTensor labels: 2D labels where each value is an object
- (0 is background)
- :return dict[str, torch.Tensor]: detection boxes, scores, labels, and masks
+ Parameters
+ ----------
+ labels : torch.ShortTensor
+ 2D labels where each value is an object (0 is background).
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ Detection boxes, scores, labels, and masks.
 """
 masks = labels_to_masks(labels)
 boxes = masks_to_boxes(masks)
 return {
 "boxes": boxes,
 # dummy confidence scores
 "scores": torch.ones(
 (boxes.shape[0],), dtype=torch.float32, device=boxes.device
 ),
 # dummy class labels
 "labels": torch.zeros(
 (boxes.shape[0],), dtype=torch.uint8, device=boxes.device
 ),
 "masks": masks,
 }


 def mean_average_precision(
 pred_labels: torch.ShortTensor, target_labels: torch.ShortTensor, **kwargs
 ) -> dict[str, torch.Tensor]:
 """Compute the mAP metric for instance segmentation.

- :param torch.ShortTensor pred_labels: 2D integer prediction labels
- :param torch.ShortTensor target_labels: 2D integer prediction labels
- :param dict **kwargs: keyword arguments passed to
- :py:class:`torchmetrics.detection.MeanAveragePrecision`
- :return dict[str, torch.Tensor]: COCO-style metrics
+ Parameters
+ ----------
+ pred_labels : torch.ShortTensor
+ 2D integer prediction labels.
+ target_labels : torch.ShortTensor
+ 2D integer target labels.
+ **kwargs
+ Keyword arguments passed to
+ :py:class:`torchmetrics.detection.MeanAveragePrecision`.
+
+ Returns
+ -------
+ dict[str, torch.Tensor]
+ COCO-style metrics.
 """
 defaults = dict(
 iou_type="segm", box_format="xyxy", max_detection_thresholds=[1, 100, 10000]
 )
 if not kwargs:
 kwargs = {}
 map_metric = MeanAveragePrecision(**(defaults | kwargs))
 map_metric.update(
 [labels_to_detection(pred_labels)], [labels_to_detection(target_labels)]
 )
 return map_metric.compute()


+def compute_3d_dice_score(
+ y_true: torch.Tensor,
+ y_pred: torch.Tensor,
+ eps: float = 1e-8,
+ threshold: float = 0.5,
+ aggregate: bool = True,
+) -> torch.Tensor:
+ """Compute 3D Dice similarity coefficient."""
+ y_pred_thresholded = (y_pred > threshold).float()
+ intersection = torch.sum(y_true * y_pred_thresholded, dim=(-3, -2, -1))
+ total = torch.sum(y_true + y_pred_thresholded, dim=(-3, -2, -1))
+ dice = (2.0 * intersection + eps) / (total + eps)
+ if aggregate:
+ return torch.mean(dice)
+ return dice
+
+
+def compute_jaccard_index(
+ y_true: torch.Tensor,
+ y_pred: torch.Tensor,
+ threshold: float = 0.5,
+) -> torch.Tensor:
+ """Compute Jaccard index (IoU)."""
+ y_pred_thresholded = y_pred > threshold
+ intersection = torch.sum(y_true & y_pred_thresholded, dim=(-3, -2, -1))
+ union = torch.sum(y_true | y_pred_thresholded, dim=(-3, -2, -1))
+ return torch.mean(intersection.float() / union.float())
+
+
+def compute_pearson_correlation_coefficient(
+ y_true: torch.Tensor, y_pred: torch.Tensor, dim: Sequence[int] | None = None
+) -> torch.Tensor:
+ """Compute Pearson correlation
coefficient.""" + if dim is None: + # default to spatial dimensions + dim = (-3, -2, -1) + y_true_centered = y_true - torch.mean(y_true, dim=dim, keepdim=True) + y_pred_centered = y_pred - torch.mean(y_pred, dim=dim, keepdim=True) + numerator = torch.sum(y_true_centered * y_pred_centered, dim=dim) + # compute stds + y_true_std = torch.sqrt(torch.sum(y_true_centered**2, dim=dim)) + y_pred_std = torch.sqrt(torch.sum(y_pred_centered**2, dim=dim)) + denominator = y_true_std * y_pred_std + # torch.full_like makes the entire tensor have the same value, + # so we have to use torch.full instead + small_correlation = torch.abs(denominator) < 1e-8 + pcc = torch.where( + small_correlation, torch.zeros_like(numerator), numerator / denominator + ) + return torch.mean(pcc) + + +class MeanAveragePrecisionNuclei(MeanAveragePrecision): + """Mean Average Precision for nuclei detection.""" + + def __init__(self, min_area: int = 20, iou_threshold: float = 0.5) -> None: + super().__init__(iou_thresholds=[iou_threshold]) + self.min_area = min_area + + def __call__(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute mean average precision for nuclei detection. + + Parameters + ---------- + prediction : torch.Tensor + Predicted nuclei segmentation masks. + target : torch.Tensor + Ground truth nuclei segmentation masks. + + Returns + ------- + torch.Tensor + Mean average precision score. + """ + prediction_labels = label(prediction > 0.5) + target_labels = label(target > 0.5) + device = prediction.device + preds = [] + targets = [] + for i, (pred_img, target_img) in enumerate( + zip(prediction_labels, target_labels) + ): + pred_props = regionprops(pred_img) + # binary mask for each instance + pred_masks = torch.zeros( + len(pred_props), *pred_img.shape, dtype=torch.bool, device=device + ) + pred_labels = torch.zeros(len(pred_props), dtype=torch.long, device=device) + pred_scores = torch.ones(len(pred_props), dtype=torch.float, device=device) + for j, prop in enumerate(pred_props): + if prop.area < self.min_area: + continue + pred_masks[j, pred_img == prop.label] = True + pred_labels[j] = 1 # class 1 for nuclei + + target_props = regionprops(target_img) + target_masks = torch.zeros( + len(target_props), *target_img.shape, dtype=torch.bool, device=device + ) + target_labels = torch.zeros( + len(target_props), dtype=torch.long, device=device + ) + for j, prop in enumerate(target_props): + if prop.area < self.min_area: + continue + target_masks[j, target_img == prop.label] = True + target_labels[j] = 1 + + preds.append( + { + "masks": pred_masks, + "labels": pred_labels, + "scores": pred_scores, + } + ) + targets.append({"masks": target_masks, "labels": target_labels}) + return super().__call__(preds, targets) + + +def ssim_loss_25d( preds: torch.Tensor, target: torch.Tensor, in_plane_window_size: tuple[int, int] = (11, 11), return_contrast_sensitivity: bool = False, -) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - """Multi-scale SSIM loss function for 2.5D volumes (3D with small depth). +) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + """ + Multi-scale SSIM loss function for 2.5D volumes (3D with small depth). + Uses uniform kernel (windows), depth-dimension window size equals to depth size. 
- :param torch.Tensor preds: predicted batch (B, C, D, W, H)
- :param torch.Tensor target: target batch
- :param tuple[int, int] in_plane_window_size: kernel width and height,
- by default (11, 11)
- :param bool return_contrast_sensitivity: whether to return contrast sensitivity
- :return torch.Tensor: SSIM for the batch
- :return Optional[torch.Tensor]: contrast sensitivity
+ Parameters
+ ----------
+ preds : torch.Tensor
+ Predicted batch (B, C, D, W, H).
+ target : torch.Tensor
+ Target batch.
+ in_plane_window_size : tuple[int, int], optional
+ Kernel width and height, by default (11, 11).
+ return_contrast_sensitivity : bool, optional
+ Whether to return contrast sensitivity, by default False.
+
+ Returns
+ -------
+ torch.Tensor or tuple[torch.Tensor, torch.Tensor]
+ SSIM for the batch, optionally with contrast sensitivity.
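+
+ Examples
+ --------
+ A sketch with random 2.5D stacks of shape ``(B, C, D, H, W)``:
+
+ >>> preds = torch.rand(2, 1, 5, 64, 64)
+ >>> target = torch.rand(2, 1, 5, 64, 64)
+ >>> ssim_loss_25d(preds, target).shape
+ torch.Size([2])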
 """
 if preds.ndim != 5:
 raise ValueError(
- f"Input shape must be (B, C, D, W, H), got input shape {preds.shape}"
+ f"Expected preds to have 5 dimensions (B, C, D, W, H), got {preds.ndim}"
 )
+ if preds.shape != target.shape:
+ raise ValueError(
+ f"Expected preds and target to have the same shape, "
+ f"got {preds.shape} and {target.shape}"
+ )
 depth = preds.shape[2]
 if depth > 15:
 warn(f"Input depth {depth} is potentially too large for 2.5D SSIM.")
 ssim_img, cs_img = compute_ssim_and_cs(
 preds,
 target,
 3,
 kernel_sigma=None,
 kernel_size=(depth, *in_plane_window_size),
 data_range=target.max(),
 kernel_type="uniform",
 )
 # aggregate to one scalar per batch
 ssim = ssim_img.view(ssim_img.shape[0], -1).mean(1)
 if return_contrast_sensitivity:
 return ssim, cs_img.view(cs_img.shape[0], -1).mean(1)
 else:
 return ssim


 def ms_ssim_25d(
 preds: torch.Tensor,
 target: torch.Tensor,
 in_plane_window_size: tuple[int, int] = (11, 11),
 clamp: bool = False,
 betas: Sequence[float] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
 ) -> torch.Tensor:
- """Multi-scale SSIM for 2.5D volumes (3D with small depth).
+ """
+ Multi-scale SSIM for 2.5D volumes (3D with small depth).
+
 Uses uniform kernel (windows), depth-dimension window size equals to depth size.
 Depth dimension is not downsampled.

 Adapted from torchmetrics.
 Original license:
 Copyright The Lightning team, http://www.apache.org/licenses/LICENSE-2.0

- :param torch.Tensor preds: predicted images
- :param torch.Tensor target: target images
- :param tuple[int, int] in_plane_window_size: kernel width and height,
- defaults to (11, 11)
- :param bool clamp: clamp to [1e-6, 1] for training stability when used in loss,
- defaults to False
- :param Sequence[float] betas: exponents of each resolution,
- defaults to (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
- :return torch.Tensor: multi-scale SSIM
+ Parameters
+ ----------
+ preds : torch.Tensor
+ Predicted images.
+ target : torch.Tensor
+ Target images.
+ in_plane_window_size : tuple[int, int], optional
+ Kernel width and height, defaults to (11, 11).
+ clamp : bool, optional
+ Clamp to [1e-6, 1] for training stability when used in loss,
+ defaults to False.
+ betas : Sequence[float], optional
+ Exponents of each resolution,
+ defaults to (0.0448, 0.2856, 0.3001, 0.2363, 0.1333).
+
+ Returns
+ -------
+ torch.Tensor
+ Multi-scale SSIM.
 """
 base_min = 1e-4
 mcs_list = []
 for _ in range(len(betas)):
- ssim, contrast_sensitivity = ssim_25d(
+ ssim, contrast_sensitivity = ssim_loss_25d(
 preds, target, in_plane_window_size, return_contrast_sensitivity=True
 )
 if clamp:
 contrast_sensitivity = contrast_sensitivity.clamp(min=base_min)
 mcs_list.append(contrast_sensitivity)
 # do not downsample along depth
 preds = F.avg_pool3d(preds, (1, 2, 2))
 target = F.avg_pool3d(target, (1, 2, 2))
 if clamp:
 ssim = ssim.clamp(min=base_min)
 mcs_list[-1] = ssim
 mcs_stack = torch.stack(mcs_list)
 betas = torch.tensor(betas, device=mcs_stack.device).view(-1, 1)
 mcs_weighted = mcs_stack**betas
 return torch.prod(mcs_weighted, axis=0).mean()
diff --git a/viscy/translation/predict_writer.py b/viscy/translation/predict_writer.py
index 75d6d4152..96d34971d 100644
--- a/viscy/translation/predict_writer.py
+++ b/viscy/translation/predict_writer.py
@@ -1,7 +1,10 @@
+"""Prediction writer for HCS virtual staining predictions in OME-Zarr format."""
+
 import logging
 import os
+from collections.abc import Sequence
 from pathlib import Path
-from typing import Literal, Optional, Sequence
+from typing import Literal, Optional

 import numpy as np
 import torch
@@ -9,7 +12,6 @@
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import BasePredictionWriter
 from numpy.typing import DTypeLike, NDArray
-
 from viscy.data.hcs import HCSDataModule, Sample

 __all__ = ["HCSPredictionWriter"]
@@ -19,6 +21,7 @@ def _pad_shape(shape: tuple[int, ...], target: int = 5) -> tuple[int, ...]:
 """
 Pad shape tuple to a target length.
+
 Vendored from ``iohub.ngff.nodes._pad_shape()``.
 """
 pad = target - len(shape)
@@ -49,7 +52,7 @@ def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> NDArray
 weights are determined by the position within the range of slices.
 If the start of `z_slice` is 0, the function returns the `new_stack` unchanged.

- Parameters:
+ Parameters
 ----------
 old_stack : NDArray
 The original stack of images to be blended.
@@ -59,12 +62,11 @@ def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> NDArray A slice object indicating the range of slices over which to perform the blending. The start and stop attributes of the slice determine the range. - Returns: + Returns ------- NDArray The blended stack of images. If `z_slice.start` is 0, returns `new_stack` unchanged. """ - if z_slice.start == 0: return new_stack depth = z_slice.stop - z_slice.start @@ -81,12 +83,15 @@ def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> NDArray class HCSPredictionWriter(BasePredictionWriter): """Callback to store virtual staining predictions as HCS OME-Zarr. - :param str output_store: Path to the zarr store to store output - :param bool write_input: Write the source and target channels too - (must be writing to a new store), - defaults to False - :param Literal['batch', 'epoch', 'batch_and_epoch'] write_interval: - When to write, defaults to "batch" + Parameters + ---------- + output_store : str + Path to the zarr store to store output. + write_input : bool, optional + Write the source and target channels too (must be writing to a new store), + by default False. + write_interval : Literal['batch', 'epoch', 'batch_and_epoch'], optional + When to write, by default "batch". """ def __init__( @@ -117,6 +122,16 @@ def _get_scale_metadata(self, metadata_store: Path) -> None: _logger.debug(f"Dataset scale {self._dataset_scale}.") def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + """ + Initialize output store and set up prediction writing at start of prediction. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning trainer instance. + pl_module : LightningModule + PyTorch Lightning module being used for predictions. + """ dm: HCSDataModule = trainer.datamodule self._get_scale_metadata(dm.data_path) self.z_padding = dm.z_window_size // 2 if dm.target_2d else 0 @@ -156,21 +171,63 @@ def write_on_batch_end( trainer: Trainer, pl_module: LightningModule, prediction: torch.Tensor, - batch_indices: Optional[Sequence[int]], + batch_indices: Sequence[int] | None, batch: Sample, batch_idx: int, dataloader_idx: int, ) -> None: + """ + Write predictions to output store at the end of each batch. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning trainer instance. + pl_module : LightningModule + PyTorch Lightning module being used for predictions. + prediction : torch.Tensor + Batch of predictions from the model. + batch_indices : Optional[Sequence[int]] + Indices of the batch samples. + batch : Sample + Input batch data. + batch_idx : int + Index of the current batch. + dataloader_idx : int + Index of the current dataloader. + """ _logger.debug(f"Writing batch {batch_idx}.") for sample_index, _ in enumerate(batch["index"][0]): self.write_sample(batch, prediction[sample_index], sample_index) def on_predict_end(self, trainer: Trainer, pl_module: LightningModule) -> None: + """ + Close output store at the end of prediction. + + Parameters + ---------- + trainer : Trainer + PyTorch Lightning trainer instance. + pl_module : LightningModule + PyTorch Lightning module being used for predictions. + """ self.plate.close() def write_sample( self, batch: Sample, sample_prediction: torch.Tensor, sample_index: int ) -> None: + """ + Write a single sample prediction to the output store. + + Parameters + ---------- + batch : Sample + Input batch data containing metadata for the sample. + sample_prediction : torch.Tensor + Prediction tensor for the sample. 
+ sample_index : int + Index of the sample within the batch. + """ _logger.debug(f"Writing sample {sample_index}.") sample_prediction = sample_prediction.cpu().numpy() img_name, t_index, z_index = [batch["index"][i][sample_index] for i in range(3)] diff --git a/viscy/unet/networks/Unet25D.py b/viscy/unet/networks/Unet25D.py index 8a34042d7..802bec409 100644 --- a/viscy/unet/networks/Unet25D.py +++ b/viscy/unet/networks/Unet25D.py @@ -1,3 +1,5 @@ +from typing import Literal + import torch import torch.nn as nn @@ -5,53 +7,68 @@ class Unet25d(nn.Module): - def __name__(self): + """2.5D U-Net neural network for volumetric image translation. + + A hybrid approach that processes 3D input stacks but outputs 2D predictions. + Combines 3D spatial information with 2D computational efficiency. + """ + + def __name__(self) -> str: return "Unet25d" def __init__( self, - in_channels=1, - out_channels=1, - in_stack_depth=5, - out_stack_depth=1, - xy_kernel_size=(3, 3), - residual=False, - dropout=0.2, - num_blocks=4, - num_block_layers=2, - num_filters=[], - task="seg", - ): - """ - Instance of 2.5D Unet. - 1.) https://elifesciences.org/articles/55502 - - Architecture takes in stack of 2d inputs given as a 3d tensor - and returns a 2d interpretation. - Learns 3d information based upon input stack, - but speeds up training by compressing 3d information before the decoding path. - Uses interruption conv layers in the Unet skip paths to + in_channels: int = 1, + out_channels: int = 1, + in_stack_depth: int = 5, + out_stack_depth: int = 1, + xy_kernel_size: tuple[int, int] = (3, 3), + residual: bool = False, + dropout: float = 0.2, + num_blocks: int = 4, + num_block_layers: int = 2, + num_filters: list[int] = [], + task: Literal["seg", "reg"] = "seg", + ) -> None: + """Initialize 2.5D U-Net. + + Architecture takes in stack of 2D inputs given as a 3D tensor + and returns a 2D interpretation. Learns 3D information based upon input stack, + but speeds up training by compressing 3D information before the decoding path. + Uses interruption conv layers in the U-Net skip paths to compress information with z-channel convolution. - :param int in_channels: number of feature channels in (1 or more) - :param int out_channels: number of feature channels out (1 or more) - :param int input_stack_depth: depth of input stack in z - :param int output_stack_depth: depth of output stack - :param int/tuple(int, int) xy_kernel_size: size of x and y dimensions - of conv kernels in blocks - :param bool residual: see name - :param float dropout: probability of dropout, between 0 and 0.5 - :param int num_blocks: number of convolutional blocks - on encoder and decoder paths - :param int num_block_layers: number of layer sequences repeated per block - :param list[int] num_filters: list of filters/feature levels - at each conv block depth - :param str task: network task (for virtual staining this is regression), - one of 'seg','reg' - :param str debug_mode: if true logs features at each step of architecture, - must be manually set + References + ---------- + https://elifesciences.org/articles/55502 + + Parameters + ---------- + in_channels : int, optional + Number of feature channels in (1 or more), by default 1. + out_channels : int, optional + Number of feature channels out (1 or more), by default 1. + in_stack_depth : int, optional + Depth of input stack in z, by default 5. + out_stack_depth : int, optional + Depth of output stack, by default 1. 
+ xy_kernel_size : int or tuple of int, optional + Size of x and y dimensions of conv kernels in blocks, by default (3, 3). + residual : bool, optional + Whether to use residual connections, by default False. + dropout : float, optional + Probability of dropout, between 0 and 0.5, by default 0.2. + num_blocks : int, optional + Number of convolutional blocks on encoder and decoder paths, by default 4. + num_block_layers : int, optional + Number of layer sequences repeated per block, by default 2. + num_filters : list of int, optional + List of filters/feature levels at each conv block depth, by default []. + task : str, optional + Network task (for virtual staining this is regression), + one of 'seg','reg', by default "seg". """ - super(Unet25d, self).__init__() + super().__init__() self.in_channels = in_channels self.num_blocks = num_blocks self.kernel_size = xy_kernel_size @@ -202,7 +219,7 @@ def __init__( # ----- Feature Logging ----- # self.log_save_folder = None - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward call of network. @@ -215,7 +232,6 @@ def forward(self, x): :param torch.tensor x: input image """ - # encoder skip_tensors = [] for i in range(self.num_blocks): @@ -240,16 +256,20 @@ def forward(self, x): x = self.terminal_block(x) return x - def register_modules(self, module_list, name): - """ - Helper function that registers modules stored in a list to the model object - so that the can be seen by PyTorch optimizer. + def register_modules(self, module_list: list[nn.Module], name: str) -> None: + """Helper function that registers modules stored in a list to the model object. + + So that they can be seen by PyTorch optimizer. Used to enable model graph creation with - non-sequential model types and dynamic layer numbers + non-sequential model types and dynamic layer numbers. - :param list(torch.nn.module) module_list: list of modules to register - :param str name: name of module type + Parameters + ---------- + module_list : list[torch.nn.module] + List of modules to register + name : str + Name of module type """ for i, module in enumerate(module_list): self.add_module(f"{name}_{str(i)}", module) diff --git a/viscy/unet/networks/Unet2D.py b/viscy/unet/networks/Unet2D.py index 0edd95362..e454942d7 100644 --- a/viscy/unet/networks/Unet2D.py +++ b/viscy/unet/networks/Unet2D.py @@ -5,6 +5,12 @@ class Unet2d(nn.Module): + """2D U-Net neural network for image-to-image translation. + + A convolutional neural network following the U-Net architecture for 2D images. + Supports both segmentation and regression tasks with configurable depth and filters. + """ + def __name__(self): return "Unet2d" @@ -20,27 +26,38 @@ def __init__( num_filters=[], task="seg", ): - """ - 2D Unet with variable input/output channels and depth (block numbers). + """Initialize 2D U-Net with variable input/output channels and depth. 
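A short instantiation sketch for the 2.5D network documented above, assuming the conventional (B, C, D, H, W) input layout (the hunk documents the stack depths but not the tensor layout):

import torch
from viscy.unet.networks.Unet25D import Unet25d

model = Unet25d(in_channels=1, out_channels=1, in_stack_depth=5, task="reg")
x = torch.rand(1, 1, 5, 256, 256)  # assumed (B, C, D, H, W)
y = model(x)  # 2D interpretation learned from the 5-slice input stack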
+ Follows 2D UNet Architecture: - 1) Unet: https://arxiv.org/pdf/1505.04597.pdf - 2) residual Unet: https://arxiv.org/pdf/1711.10684.pdf - - :param int in_channels: number of feature channels in - :param int out_channels: number of feature channels out - :param int/tuple(int,int) kernel_size: size of x and y dimensions - of conv kernels in blocks - :param bool residual: see name - :param float dropout: probability of dropout, between 0 and 0.5 - :param int num_blocks: number of convolutional blocks on encoder and decoder - :param int num_block_layers: number of layers per block - :param list[int] num_filters: list of filters/feature levels - at each conv block depth - :param str task: network task (for virtual staining this is regression), - one of 'seg','reg' - """ - super(Unet2d, self).__init__() + References + ---------- + 1) U-Net: https://arxiv.org/pdf/1505.04597.pdf + 2) Residual U-Net: https://arxiv.org/pdf/1711.10684.pdf + + Parameters + ---------- + in_channels : int, optional + Number of feature channels in, by default 1. + out_channels : int, optional + Number of feature channels out, by default 1. + kernel_size : int or tuple of int, optional + Size of x and y dimensions of conv kernels in blocks, by default (3, 3). + residual : bool, optional + Whether to use residual connections, by default False. + dropout : float, optional + Probability of dropout, between 0 and 0.5, by default 0.2. + num_blocks : int, optional + Number of convolutional blocks on encoder and decoder, by default 4. + num_block_layers : int, optional + Number of layers per block, by default 2. + num_filters : list of int, optional + List of filters/feature levels at each conv block depth, by default []. + task : str, optional + Network task (for virtual staining this is regression), + one of 'seg','reg', by default "seg". + """ + super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size @@ -168,19 +185,26 @@ def __init__( ) def forward(self, x, validate_input=False): - """ - Forward call of network - - x -> Torch.tensor: input image stack + """Forward pass through the 2D U-Net. Call order: - => num_block 2D convolutional blocks, with downsampling in between (encoder) - => num_block 2D convolutional blocks, with upsampling between them (decoder) - => skip connections between corresponding blocks on encoder and decoder - => terminal block collapses to output dimensions - - :param torch.tensor x: input image - :param bool validate_input: Deactivates assertions which are redundant - if forward pass is being traced by tensorboard writer. + => num_block 2D convolutional blocks, with downsampling in between (encoder) + => num_block 2D convolutional blocks, with upsampling between them (decoder) + => skip connections between corresponding blocks on encoder and decoder + => terminal block collapses to output dimensions + + Parameters + ---------- + x : torch.tensor + Input image stack. + validate_input : bool, optional + Deactivates assertions which are redundant if forward pass is being + traced by tensorboard writer, by default False. + + Returns + ------- + torch.tensor + Network output with same spatial dimensions as input. """ # handle input exceptions if validate_input: @@ -211,15 +235,19 @@ def forward(self, x, validate_input=False): return x.unsqueeze(2) def register_modules(self, module_list, name): - """ - Helper function that registers modules stored in a list to the model object - so that they can be seen by PyTorch optimizer. 
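The registration helper re-documented in this hunk exists because parameters held in plain Python lists are invisible to Module.parameters(); a self-contained sketch of the same add_module pattern:

import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        convs = [nn.Conv2d(8, 8, 3, padding=1) for _ in range(3)]
        # A bare list hides these parameters from self.parameters();
        # registering each module exposes them, as register_modules does.
        for i, module in enumerate(convs):
            self.add_module(f"conv_{i}", module)

assert len(list(Toy().parameters())) == 6  # weight + bias per conv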
+ """Helper function that registers modules stored in a list to the model object. - Used to enable model graph creation with - non-sequential model types and dynamic layer numbers + So that they can be seen by PyTorch optimizer. - :param list(torch.nn.module) module_list: list of modules to register - :param str name: name of module type + Used to enable model graph creation with + non-sequential model types and dynamic layer numbers. + + Parameters + ---------- + module_list : list[torch.nn.module] + List of modules to register + name : str + Name of module type """ for i, module in enumerate(module_list): self.add_module(f"{name}_{str(i)}", module) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index d63b65a7d..c529b0a13 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -1,12 +1,12 @@ -""" -Fully Convolutional Masked Autoencoder as described in ConvNeXt V2 -based on the official JAX example in +"""Fully Convolutional Masked Autoencoder as described in ConvNeXt V2. + +Based on the official JAX example in https://github.com/facebookresearch/ConvNeXt-V2/blob/main/TRAINING.md#implementing-fcmae-with-masked-convolution-in-jax and timm's dense implementation of the encoder in ``timm.models.convnext`` """ import math -from typing import Sequence +from collections.abc import Sequence import torch from monai.networks.blocks import UpSample @@ -40,7 +40,8 @@ def _init_weights(module: nn.Module) -> None: def generate_mask( target: Size, stride: int, mask_ratio: float, device: str ) -> BoolTensor: - """ + """Generate random boolean mask for masked autoencoder training. + :param Size target: target shape :param int stride: total stride :param float mask_ratio: ratio of the pixels to mask @@ -55,7 +56,8 @@ def generate_mask( def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: - """ + """Upsample boolean mask to match target spatial dimensions. + :param BoolTensor mask: low-resolution boolean mask (B1HW) :param Size target: target size (BCHW) :return BoolTensor: upsampled boolean mask (B1HW) @@ -73,7 +75,8 @@ def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: def masked_patchify(features: Tensor, unmasked: BoolTensor | None = None) -> Tensor: - """ + """Convert spatial features to channel-last patches, optionally masked. + :param Tensor features: input image features (BCHW) :param BoolTensor unmasked: boolean foreground mask (B1HW) :return Tensor: masked channel-last features (BLC, L = H * W * mask_ratio) @@ -91,7 +94,8 @@ def masked_patchify(features: Tensor, unmasked: BoolTensor | None = None) -> Ten def masked_unpatchify( features: Tensor, out_shape: Size, unmasked: BoolTensor | None = None ) -> Tensor: - """ + """Convert channel-last patches back to spatial features. + :param Tensor features: dense channel-last features (BLC) :param Size out_shape: output shape (BCHW) :param BoolTensor | None unmasked: boolean foreground mask, defaults to None @@ -151,7 +155,8 @@ def __init__( self.shortcut = nn.Identity() def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: - """ + """Forward pass through masked ConvNeXt V2 block. + :param Tensor x: input tensor (BCHW) :param BoolTensor | None unmasked: boolean foreground mask, defaults to None :return Tensor: output tensor (BCHW) @@ -229,7 +234,8 @@ def __init__( in_channels = out_channels def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: - """ + """Forward pass through masked ConvNeXt V2 stage. 
+ :param Tensor x: input tensor (BCHW) :param BoolTensor | None unmasked: boolean foreground mask, defaults to None :return Tensor: output tensor (BCHW) @@ -281,7 +287,8 @@ def __init__( self.norm = nn.LayerNorm(out_channels) def forward(self, x: Tensor, unmasked: BoolTensor = None) -> Tensor: - """ + """Forward pass through masked adaptive projection layer. + :param Tensor x: input tensor (BCDHW) :param BoolTensor unmasked: boolean foreground mask (B1HW), defaults to None :return Tensor: output tensor (BCHW) @@ -305,6 +312,19 @@ def forward(self, x: Tensor, unmasked: BoolTensor = None) -> Tensor: class MaskedMultiscaleEncoder(nn.Module): + """Multi-scale encoder with masking support for FC-MAE architecture. + + Implements hierarchical feature extraction through multiple ConvNeXt V2 stages + with optional random masking for self-supervised pretraining. + + :param int in_channels: input channels + :param Sequence[int] stage_blocks: number of blocks per encoder stage + :param Sequence[int] dims: feature dimensions at each stage + :param float drop_path_rate: stochastic depth rate + :param Sequence[int] stem_kernel_size: kernel sizes for adaptive projection + :param int in_stack_depth: input stack depth for 3D input + """ + def __init__( self, in_channels: int, @@ -342,7 +362,8 @@ def __init__( def forward( self, x: Tensor, mask_ratio: float = 0.0 ) -> tuple[list[Tensor], BoolTensor | None]: - """ + """Extract multi-scale features with optional masking. + :param Tensor x: input tensor (BCDHW) :param float mask_ratio: ratio of the feature maps to mask, defaults to 0.0 (no masking) @@ -367,6 +388,18 @@ def forward( class PixelToVoxelShuffleHead(nn.Module): + """Pixel-to-voxel reconstruction head using pixel shuffle upsampling. + + Converts 2D feature maps to 3D output volumes through pixel shuffle + upsampling and channel-to-depth reshaping. + + :param int in_channels: input feature channels + :param int out_channels: output channels per voxel + :param int out_stack_depth: output stack depth (Z dimension) + :param int xy_scaling: spatial upsampling factor + :param bool pool: whether to apply pooling in upsampling + """ + def __init__( self, in_channels: int, @@ -389,6 +422,11 @@ def __init__( ) def forward(self, x: Tensor) -> Tensor: + """Reconstruct 3D volume from 2D features. + + :param Tensor x: input 2D features (BCHW) + :return Tensor: reconstructed 3D volume (BCDHW) + """ x = self.upsample(x) b, _, h, w = x.shape x = x.reshape(b, self.out_channels, self.out_stack_depth, h, w) @@ -396,6 +434,28 @@ def forward(self, x: Tensor) -> Tensor: class FullyConvolutionalMAE(nn.Module): + """Fully Convolutional Masked Autoencoder for self-supervised learning. + + Implements FC-MAE architecture combining a masked multi-scale encoder + with a UNet-style decoder for reconstruction tasks. Supports both + pretraining with masking and fine-tuning for downstream tasks. 
+ + # TODO: MANUAL_REVIEW - Complex encoder-decoder architecture with masking + + :param int in_channels: input channels + :param int out_channels: output channels + :param Sequence[int] encoder_blocks: blocks per encoder stage + :param Sequence[int] dims: feature dimensions per stage + :param float encoder_drop_path_rate: encoder stochastic depth rate + :param Sequence[int] stem_kernel_size: adaptive projection kernel sizes + :param int in_stack_depth: input stack depth for 3D data + :param int decoder_conv_blocks: decoder convolution blocks per stage + :param bool pretraining: whether in pretraining mode (returns mask) + :param bool head_conv: whether to use convolutional reconstruction head + :param int head_conv_expansion_ratio: expansion ratio for conv head + :param bool head_conv_pool: whether to use pooling in conv head + """ + def __init__( self, in_channels: int, @@ -460,6 +520,15 @@ def __init__( self.pretraining = pretraining def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: + """Forward pass through FC-MAE architecture. + + Encodes input with optional masking, decodes through UNet decoder, + and reconstructs output through pixel-to-voxel head. + + :param Tensor x: input tensor (BCDHW) + :param float mask_ratio: masking ratio for pretraining (0.0 = no mask) + :return Tensor: reconstructed output (BCDHW) or tuple with mask + """ x, mask = self.encoder(x, mask_ratio=mask_ratio) x.reverse() x = self.decoder(x) diff --git a/viscy/unet/networks/layers/ConvBlock2D.py b/viscy/unet/networks/layers/ConvBlock2D.py index 114777a79..c5f35c9e5 100644 --- a/viscy/unet/networks/layers/ConvBlock2D.py +++ b/viscy/unet/networks/layers/ConvBlock2D.py @@ -1,3 +1,7 @@ +"""2D convolutional blocks for U-Net architectures.""" + +from typing import Literal + import numpy as np import torch import torch.nn as nn @@ -5,22 +9,27 @@ class ConvBlock2D(nn.Module): + """2D convolutional block for U-Net lateral layers with configurable architecture. + + Supports dynamic layer configuration, normalization, activation functions, + residual connections, and various filter progression strategies. + """ + def __init__( self, - in_filters, - out_filters, - dropout=False, - norm="batch", - residual=True, - activation="relu", - transpose=False, - kernel_size=3, - num_repeats=3, - filter_steps="first", - layer_order="can", - ): - """ - Convolutional block for lateral layers in Unet + in_filters: int, + out_filters: int, + dropout: float | bool = False, + norm: Literal["batch", "instance"] = "batch", + residual: bool = True, + activation: Literal["relu", "leakyrelu", "elu", "selu", "linear"] = "relu", + transpose: bool = False, + kernel_size: int | tuple[int, int] = 3, + num_repeats: int = 3, + filter_steps: Literal["linear", "first", "last"] = "first", + layer_order: str = "can", + ) -> None: + """Initialize convolutional block for lateral layers in U-Net. 
Format for layer initialization is as follows: if layer type specified @@ -46,8 +55,7 @@ def __init__( :param str layer_order: order of conv, norm, and act layers in block: 'can', 'cna', 'nca', etc """ - - super(ConvBlock2D, self).__init__() + super().__init__() self.in_filters = in_filters self.out_filters = out_filters self.dropout = dropout @@ -262,9 +270,8 @@ def __init__( ) self.register_modules(self.act_list, f"{self.activation}_act") - def forward(self, x, validate_input=False): - """ - Forward call of convolutional block + def forward(self, x: torch.Tensor, validate_input: bool = False) -> torch.Tensor: + """Forward pass through the convolutional block. Order of layers within the block is defined by the 'layer_order' parameter, which is a string of 'c's, 'a's and 'n's @@ -335,19 +342,24 @@ def forward(self, x, validate_input=False): return x - def model(self): - """ + def model(self) -> nn.Sequential: + """Create a sequential model from the convolutional block layers. + Allows calling of parameters inside ConvBlock object: - 'ConvBlock.model().parameters()'' + 'ConvBlock.model().parameters()' - Layer order: convolution -> normalization -> activation + Layer order: convolution -> normalization -> activation We can make a list of layer modules and unpack them into nn.Sequential. - Note: this is distinct from the forward call - because we want to use the forward call with addition, - since this is a residual block. - The forward call performs the residial calculation, - and all the parameters can be seen by the optimizer when given this model. + Note: this is distinct from the forward call because we want to use + the forward call with addition, since this is a residual block. + The forward call performs the residual calculation, and all the + parameters can be seen by the optimizer when given this model. + + Returns + ------- + nn.Sequential + Sequential model containing all layers in the block. """ layers = [] @@ -362,16 +374,21 @@ def model(self): return nn.Sequential(*layers) - def register_modules(self, module_list, name): - """ + def register_modules(self, module_list: list[nn.Module], name: str) -> None: + """Register modules from a list to enable PyTorch optimizer access. + Helper function that registers modules stored in a list to the model object so that they can be seen by PyTorch optimizer. - Used to enable model graph creation - with non-sequential model types and dynamic layer numbers + Used to enable model graph creation with non-sequential model types + and dynamic layer numbers. - :param list(torch.nn.module) module_list: list of modules to register - :param str name: name of module type + Parameters + ---------- + module_list : list of torch.nn.Module + List of PyTorch modules to register. + name : str + Name prefix for the module type. """ for i, module in enumerate(module_list): self.add_module(f"{name}_{str(i)}", module) diff --git a/viscy/unet/networks/layers/ConvBlock3D.py b/viscy/unet/networks/layers/ConvBlock3D.py index 893c612ef..8f5339277 100644 --- a/viscy/unet/networks/layers/ConvBlock3D.py +++ b/viscy/unet/networks/layers/ConvBlock3D.py @@ -1,3 +1,5 @@ +from typing import Literal + import numpy as np import torch import torch.nn as nn @@ -5,23 +7,64 @@ class ConvBlock3D(nn.Module): + """3D convolutional building block for volumetric neural networks. + + A flexible 3D convolutional block designed for processing volumetric data + such as medical imaging, microscopy, and video sequences. 
Supports residual + connections, various normalization schemes, activation functions, and + configurable layer ordering for deep 3D U-Net architectures. + + The block processes tensors in [..., z, x, y] or [..., z, y, x] format + and provides dynamic layer configuration with support for transpose + convolutions, dropout, and multiple padding strategies optimized for + volumetric convolution operations. + + Parameters + ---------- + in_filters : int + Number of input feature channels. + out_filters : int + Number of output feature channels. + dropout : float or bool, default=False + Dropout probability. If False, no dropout is applied. + norm : {"batch", "instance"}, default="batch" + Normalization type to apply. + residual : bool, default=True + Whether to include residual connections. + activation : {"relu", "leakyrelu", "elu", "selu", "linear"}, default="relu" + Activation function type. + transpose : bool, default=False + Whether to use transpose convolution layers. + kernel_size : int or tuple of int, default=(3, 3, 3) + 3D convolutional kernel size. + num_repeats : int, default=3 + Number of convolutional layers in the block. + filter_steps : {"linear", "first", "last"}, default="first" + Strategy for channel dimension changes across layers. + layer_order : str, default="can" + Order of conv (c), activation (a), normalization (n) layers. + padding : str, int, tuple or None, default=None + Padding strategy for convolutions. + """ + def __init__( self, - in_filters, - out_filters, - dropout=False, - norm="batch", - residual=True, - activation="relu", - transpose=False, - kernel_size=(3, 3, 3), - num_repeats=3, - filter_steps="first", - layer_order="can", - padding=None, - ): + in_filters: int, + out_filters: int, + dropout: float | bool = False, + norm: Literal["batch", "instance"] = "batch", + residual: bool = True, + activation: Literal["relu", "leakyrelu", "elu", "selu", "linear"] = "relu", + transpose: bool = False, + kernel_size: int | tuple[int, int, int] = (3, 3, 3), + num_repeats: int = 3, + filter_steps: Literal["linear", "first", "last"] = "first", + layer_order: str = "can", + padding: str | int | tuple[int, ...] | None = None, + ) -> None: """ Convolutional block for lateral layers in Unet. + This block only accepts tensors of dimensions in order [...,z,x,y] or [...,z,y,x] @@ -61,8 +104,7 @@ def __init__( :paramn str/tuple(int)/tuple/None padding: convolutional padding, see docstring for details """ - - super(ConvBlock3D, self).__init__() + super().__init__() self.in_filters = in_filters self.out_filters = out_filters self.dropout = dropout @@ -244,7 +286,7 @@ def __init__( ) self.register_modules(self.act_list, f"{self.activation}_act") - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward call of convolutional block @@ -310,10 +352,9 @@ def forward(self, x): return x - def model(self): + def model(self) -> nn.Sequential: """ - Allows calling of parameters inside ConvBlock object: - 'ConvBlock.model().parameters()'' + Allows calling of parameters inside ConvBlock object. Layer order: convolution -> normalization -> activation @@ -337,10 +378,9 @@ def model(self): return nn.Sequential(*layers) - def register_modules(self, module_list, name): + def register_modules(self, module_list: list[nn.Module], name: str) -> None: """ - Helper function that registers modules stored in a list to the model object - so that the can be seen by PyTorch optimizer. + Helper function that registers modules for PyTorch optimizer visibility. 
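Usage sketches for the two blocks documented in this file and the previous one, assuming same-padding convolutions so spatial sizes are preserved:

import torch
from viscy.unet.networks.layers.ConvBlock2D import ConvBlock2D
from viscy.unet.networks.layers.ConvBlock3D import ConvBlock3D

block2d = ConvBlock2D(in_filters=16, out_filters=32, num_repeats=2)
y2d = block2d(torch.rand(1, 16, 64, 64))  # (B, C, H, W)

block3d = ConvBlock3D(in_filters=8, out_filters=8, kernel_size=(3, 3, 3))
y3d = block3d(torch.rand(1, 8, 5, 64, 64))  # (B, C, Z, Y, X) per the docstring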
Used to enable model graph creation with non-sequential model types and dynamic layer numbers diff --git a/viscy/unet/networks/unext2.py b/viscy/unet/networks/unext2.py index c2403fc9b..303bd0a67 100644 --- a/viscy/unet/networks/unext2.py +++ b/viscy/unet/networks/unext2.py @@ -1,4 +1,5 @@ -from typing import Callable, Literal, Sequence +from collections.abc import Callable, Sequence +from typing import Literal import timm import torch @@ -14,17 +15,22 @@ def icnr_init( upsample_dims: int, init: Callable = nn.init.kaiming_normal_, ): - """ - ICNR initialization for 2D/3D kernels adapted from Aitken et al.,2017 , - "Checkerboard artifact free sub-pixel convolution". + """ICNR initialization for 2D/3D kernels. + Adapted from Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution". Adapted from MONAI v1.2.0, added support for upsampling dimensions that are not the same as the kernel dimension. - :param conv: convolution layer - :param upsample_factor: upsample factor - :param upsample_dims: upsample dimensions, 2 or 3 - :param init: initialization function + Parameters + ---------- + conv : nn.Module + Convolution layer to initialize. + upsample_factor : int + Upsample factor. + upsample_dims : int + Upsample dimensions, 2 or 3. + init : Callable, optional + Initialization function, by default nn.init.kaiming_normal_. """ out_channels, in_channels, *dims = conv.weight.shape scale_factor = upsample_factor**upsample_dims @@ -84,6 +90,19 @@ def __init__( ) def forward(self, x: Tensor): + """Forward pass through UNeXt2 stem with depth-to-channel projection. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, C, D, H, W) where D is the stack depth. + + Returns + ------- + torch.Tensor + Output tensor with depth projected to channels, shape (B, C*D', H', W') + where D' = D // kernel_size[0] after 3D convolution. + """ x = self.conv(x) b, c, d, h, w = x.shape # project Z/depth into channels @@ -117,6 +136,29 @@ def __init__( def compute_stem_channels( self, in_stack_depth, stem_kernel_size, stem_stride_depth, in_channels_encoder ): + """Compute required 3D stem output channels for encoder compatibility. + + Parameters + ---------- + in_stack_depth : int + Input stack depth dimension. + stem_kernel_size : tuple[int, int, int] + 3D convolution kernel size. + stem_stride_depth : int + Stride in the depth dimension. + in_channels_encoder : int + Required input channels for the encoder after depth projection. + + Returns + ------- + int + Required output channels for the 3D stem convolution. + + Raises + ------ + ValueError + If channel dimensions cannot be matched with current configuration. + """ stem3d_out_depth = ( in_stack_depth - stem_kernel_size[0] ) // stem_stride_depth + 1 @@ -129,6 +171,19 @@ def compute_stem_channels( return stem3d_out_channels def forward(self, x: Tensor): + """Forward pass with 3D convolution and depth-to-channel mapping. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, C, D, H, W) where D is the input stack depth. + + Returns + ------- + torch.Tensor + Output tensor with depth projected to channels, maintaining spatial + dimensions after strided 3D convolution. + """ x = self.conv(x) b, c, d, h, w = x.shape # project Z/depth into channels @@ -137,6 +192,16 @@ def forward(self, x: Tensor): class UNeXt2UpStage(nn.Module): + """UNeXt2 decoder upsampling stage with skip connection fusion. 
+ + Implements hierarchical feature upsampling using either deconvolution or + pixel shuffle, followed by ConvNeXt blocks for feature refinement. Combines + low-resolution features with high-resolution skip connections for multi-scale + feature fusion. + + # TODO: MANUAL_REVIEW - ConvNeXt block integration with skip connections + """ + def __init__( self, in_channels: int, @@ -191,10 +256,20 @@ def __init__( ) def forward(self, inp: Tensor, skip: Tensor) -> Tensor: - """ - :param Tensor inp: Low resolution features - :param Tensor skip: High resolution skip connection features - :return Tensor: High resolution features + """Forward pass with upsampling and skip connection fusion. + + Parameters + ---------- + inp : torch.Tensor + Low resolution input features from deeper decoder stage. + skip : torch.Tensor + High resolution skip connection features from encoder. + + Returns + ------- + torch.Tensor + Upsampled and refined features combining both inputs through + ConvNeXt blocks or residual units. """ inp = self.upsample(inp) inp = torch.cat([inp, skip], dim=1) @@ -202,6 +277,15 @@ def forward(self, inp: Tensor, skip: Tensor) -> Tensor: class PixelToVoxelHead(nn.Module): + """Head module for converting 2D features to 3D voxel output. + + Performs 2D-to-3D reconstruction using pixel shuffle upsampling and 3D + convolutions. Applies depth channel expansion and spatial upsampling to + generate volumetric outputs from 2D feature representations. + + # TODO: MANUAL_REVIEW - 2D to 3D reconstruction mechanism + """ + def __init__( self, in_channels: int, @@ -238,6 +322,19 @@ def __init__( self.out_stack_depth = out_stack_depth def forward(self, x: Tensor) -> Tensor: + """Forward pass for 2D to 3D voxel reconstruction. + + Parameters + ---------- + x : torch.Tensor + Input 2D feature tensor of shape (B, C, H, W). + + Returns + ------- + torch.Tensor + Output 3D voxel tensor with upsampled spatial dimensions and + reconstructed depth, shape (B, out_channels, out_stack_depth, H', W'). + """ x = self.upsample(x) d = self.out_stack_depth + 2 b, c, h, w = x.shape @@ -255,11 +352,32 @@ def __init__(self) -> None: super().__init__() def forward(self, x: Tensor) -> Tensor: + """Forward pass adding singleton depth dimension. + + Parameters + ---------- + x : torch.Tensor + Input 2D tensor of shape (B, C, H, W). + + Returns + ------- + torch.Tensor + Output 3D tensor with singleton depth dimension, shape (B, C, 1, H, W). + """ x = x.unsqueeze(2) return x class UNeXt2Decoder(nn.Module): + """UNeXt2 hierarchical decoder with multi-stage upsampling. + + Implements progressive upsampling through multiple UNeXt2UpStage modules, + combining features from different encoder scales through skip connections. + Each stage performs feature upsampling and refinement using ConvNeXt blocks. + + # TODO: MANUAL_REVIEW - Multi-scale feature fusion strategy + """ + def __init__( self, num_channels: list[int], @@ -286,6 +404,20 @@ def __init__( self.decoder_stages.append(stage) def forward(self, features: Sequence[Tensor]) -> Tensor: + """Forward pass through hierarchical decoder stages. + + Parameters + ---------- + features : Sequence[torch.Tensor] + List of multi-scale encoder features, ordered from lowest to highest + resolution. First element is the bottleneck feature. + + Returns + ------- + torch.Tensor + Decoded high-resolution features after progressive upsampling and + skip connection fusion through all decoder stages. 
+ """ feat = features[0] # padding features.append(None) @@ -295,6 +427,16 @@ def forward(self, features: Sequence[Tensor]) -> Tensor: class UNeXt2(nn.Module): + """UNeXt2: ConvNeXt-based U-Net for 3D-to-2D-to-3D processing. + + Advanced transformer-inspired U-Net architecture using ConvNeXt backbones + for hierarchical feature extraction. Performs 3D-to-2D projection via stem, + 2D multi-scale processing through ConvNeXt encoder-decoder, and 2D-to-3D + reconstruction via specialized head modules. + + # TODO: MANUAL_REVIEW - ConvNeXt transformer integration patterns + """ + def __init__( self, in_channels: int = 1, @@ -361,6 +503,19 @@ def num_blocks(self) -> int: return 6 def forward(self, x: Tensor) -> Tensor: + """Forward pass through complete UNeXt2 architecture. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (B, C, D, H, W) where D is the input stack depth. + + Returns + ------- + torch.Tensor + Output tensor of shape (B, out_channels, out_stack_depth, H', W') + after 3D-to-2D-to-3D processing through ConvNeXt backbone. + """ x = self.stem(x) x: list = self.encoder_stages(x) x.reverse() diff --git a/viscy/utils/aux_utils.py b/viscy/utils/aux_utils.py index f49137beb..1a6e37ddc 100644 --- a/viscy/utils/aux_utils.py +++ b/viscy/utils/aux_utils.py @@ -1,17 +1,34 @@ -"""Auxiliary utility functions""" +"""Auxiliary utility functions.""" import iohub.ngff as ngff import yaml def _assert_unique_subset(subset, superset, name): - """ - Helper function to allow for clean code: - Throws error if unique elements of subset are not a subset of - unique elements of superset. - - Returns unique elements of subset if given a list. If subset is -1, - returns all unique elements of superset + """Check that unique elements of subset are a subset of superset. + + Helper function to allow for clean code: Throws error if unique elements + of subset are not a subset of unique elements of superset. + + Parameters + ---------- + subset : list or int + Subset to validate. If -1, returns all unique elements of superset. + superset : list + Superset to validate against. + name : str + Name of the parameter being validated (for error messages). + + Returns + ------- + set + Unique elements of subset if given a list. If subset is -1, + returns all unique elements of superset. + + Raises + ------ + AssertionError + If subset is not a subset of superset. """ if subset == -1: subset = superset @@ -33,28 +50,38 @@ def validate_metadata_indices( slice_ids=[], pos_ids=[], ): - """ - Check the availability of indices provided timepoints, channels, positions - and slices for all data, and returns only the available of the specified - indices. + """Check availability of indices for timepoints, channels, positions and slices. + Returns only the available indices from the specified indices. If input ids are None, the indices for that parameter will not be evaluated. If input ids are -1, all indices for that parameter will be returned. 
-    Assumes uniform structure, as such structure is required for HCS compatibility
-
-    :param str zarr_dir: HCS-compatible zarr directory to validate indices against
-    :param list time_ids: check availability of these timepoints in image
-        metadata
-    :param list channel_ids: check availability of these channels in image
-        metadata
-    :param list pos_ids: Check availability of positions in zarr_dir
-    :param list slice_ids: Check availability of z slices in image metadata
-
-    :return dict indices_metadata: All indices found given input
-    :raise AssertionError: If not all channels, timepoints, positions
-        or slices are present
+    Assumes uniform structure, as such structure is required for HCS compatibility.
+
+    Parameters
+    ----------
+    zarr_dir : str
+        HCS-compatible zarr directory to validate indices against.
+    time_ids : list, optional
+        Check availability of these timepoints in image metadata, by default [].
+    channel_ids : list, optional
+        Check availability of these channels in image metadata, by default [].
+    slice_ids : list, optional
+        Check availability of z slices in image metadata, by default [].
+    pos_ids : list, optional
+        Check availability of positions in zarr_dir, by default [].
+
+    Returns
+    -------
+    dict
+        Dictionary with keys 'time_ids', 'channel_ids', 'slice_ids', 'pos_ids'
+        containing all indices found given input.
+
+    Raises
+    ------
+    AssertionError
+        If not all channels, timepoints, positions or slices are present.
     """
     plate = ngff.open_ome_zarr(zarr_dir, layout="hcs", mode="r")
     position_path, position = next(plate.positions())
@@ -87,13 +114,19 @@ def validate_metadata_indices(
 def read_config(config_fname):
-    """Read the config file in yml format
+    """Read the config file in yml format.
 
-    :param str config_fname: fname of config yaml with its full path
-    :return: dict config: Configuration parameters
-    """
+    Parameters
+    ----------
+    config_fname : str
+        Filename of config yaml with its full path.
 
-    with open(config_fname, "r") as f:
+    Returns
+    -------
+    dict
+        Configuration parameters.
+    """
+    with open(config_fname) as f:
         config = yaml.safe_load(f)
     return config
diff --git a/viscy/utils/cli_utils.py b/viscy/utils/cli_utils.py
index 4223e6784..92dcef04b 100644
--- a/viscy/utils/cli_utils.py
+++ b/viscy/utils/cli_utils.py
@@ -1,3 +1,5 @@
+"""Command-line interface utilities for data processing and visualization."""
+
 import collections
 import os
 import re
@@ -8,11 +10,21 @@
 def unique_tags(directory):
-    """
-    Returns list of unique nume tags from data directory
+    """Return list of unique name tags from data directory.
+
+    Parameters
+    ----------
+    directory : str
+        Directory containing '.tif' files.
 
-    :param str directory: directory containing '.tif' files
-    TODO: Remove, unused and poorly written
+    Returns
+    -------
+    dict
+        Dictionary of unique tags and their counts.
+
+    Notes
+    -----
+    TODO: Remove, unused and poorly written.
     """
     files = [
         f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))
@@ -29,10 +41,11 @@
     return tags
 
 
-class MultiProcessProgressBar(object):
-    """
+class MultiProcessProgressBar:
+    """Progress bar for multi-processed tasks.
+
     Provides the ability to create & update a single progress bar for multi-depth
-    multi-processed tasks by calling updates on a single object
+    multi-processed tasks by calling updates on a single object.
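A short sketch of the YAML helper documented above; the file path and keys are hypothetical:

# contents of config.yml (illustrative):
#   zarr_dir: /data/plate.zarr
#   channel_ids: [0, 1]
from viscy.utils.aux_utils import read_config

config = read_config("config.yml")  # hypothetical path
print(config["zarr_dir"], config["channel_ids"])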
""" def __init__(self, total_updates): @@ -40,20 +53,32 @@ def __init__(self, total_updates): self.current = 0 def tick(self, process): + """Update progress bar with current process status. + + Parameters + ---------- + process : str + Description of the current process being executed. + """ self.current += 1 show_progress_bar(self.dataloader, self.current, process) def show_progress_bar(dataloader, current, process="training", interval=1): - """ - Utility function to print tensorflow-like progress bar. + """Print TensorFlow-like progress bar for batch processing. Written instead of using tqdm to allow for custom progress bar readouts. - :param iterable dataloader: dataloader currently being processed - :param int current: current index in dataloader - :param str proces: current process being performed - :param int interval: interval at which to update progress bar + Parameters + ---------- + dataloader : iterable + Dataloader currently being processed. + current : int + Current index in dataloader. + process : str, optional + Current process being performed, by default "training". + interval : int, optional + Interval at which to update progress bar, by default 1. """ current += 1 bar_length = 50 @@ -81,16 +106,25 @@ def show_progress_bar(dataloader, current, process="training", interval=1): def save_figure(data, save_folder, name, title=None, vmax=0, ext=".png"): - """ + """Save image data as PNG or JPEG figure. + Saves .png or .jpeg figure of data to folder save_folder under 'name'. - 'data' must be a 3d tensor or numpy array, in channels_first format - - :param numpy.ndarray/torch.tensor data: input image/stack data to save - :param str save_folder: global path to folder where data is saved. - :param str name: name of data, no extension specified - :param str/None title: image title, if none specified, defaults used - :param float vmax: value to normalize figure to, by default uses data max - :param str ext: image save file extension + 'data' must be a 3d tensor or numpy array, in channels_first format. + + Parameters + ---------- + data : numpy.ndarray or torch.Tensor + Input image/stack data to save in channels_first format. + save_folder : str + Global path to folder where data is saved. + name : str + Name of data, no extension specified. + title : str, optional + Image title, if none specified, defaults used, by default None. + vmax : float, optional + Value to normalize figure to, by default 0 (uses data max). + ext : str, optional + Image save file extension, by default ".png". """ assert len(data.shape) == 3, f"'{len(data.shape)}d' data must be 3-dimensional" diff --git a/viscy/utils/image_utils.py b/viscy/utils/image_utils.py index a95691162..2462e3e15 100644 --- a/viscy/utils/image_utils.py +++ b/viscy/utils/image_utils.py @@ -1,14 +1,39 @@ -"""Utility functions for processing images""" +"""Utility functions for processing images.""" import itertools import sys +from typing import Any import numpy as np +from numpy.typing import ArrayLike, NDArray import viscy.utils.normalize as normalize -def im_bit_convert(im, bit=16, norm=False, limit=[]): +def im_bit_convert( + im: ArrayLike, bit: int = 16, norm: bool = False, limit: list[float] = [] +) -> NDArray[Any]: + """Convert image to specified bit depth with optional normalization. + + FIXME: Verify parameter types and exact behavior for edge cases. + + Parameters + ---------- + im : array_like + Input image to convert. + bit : int, optional + Target bit depth (8 or 16), by default 16. 
+ norm : bool, optional + Whether to normalize image to [0, 2^bit-1] range, by default False. + limit : list, optional + Min/max values for normalization. If empty, uses image min/max, + by default []. + + Returns + ------- + np.array + Image converted to specified bit depth. + """ im = im.astype( np.float32, copy=False ) # convert to float32 without making a copy to save memory @@ -29,27 +54,53 @@ def im_bit_convert(im, bit=16, norm=False, limit=[]): return im -def im_adjust(img, tol=1, bit=8): - """ - Stretches contrast of the image and converts to 'bit'-bit. - Useful for weight-maps in masking +def im_adjust(img: ArrayLike, tol: int | float = 1, bit: int = 8) -> NDArray[Any]: + """Stretch contrast of the image and convert to specified bit depth. + + Useful for weight-maps in masking. + + Parameters + ---------- + img : array_like + Input image to adjust. + tol : int or float, optional + Tolerance percentile for contrast stretching, by default 1. + bit : int, optional + Target bit depth, by default 8. + + Returns + ------- + np.array + Contrast-adjusted image in specified bit depth. """ limit = np.percentile(img, [tol, 100 - tol]) im_adjusted = im_bit_convert(img, bit=bit, norm=True, limit=limit.tolist()) return im_adjusted -def grid_sample_pixel_values(im, grid_spacing): - """Sample pixel values in the input image at the grid. Any incomplete - grids (remainders of modulus operation) will be ignored. +def grid_sample_pixel_values( + im: NDArray[Any], grid_spacing: int +) -> tuple[NDArray[Any], NDArray[Any], NDArray[Any]]: + """Sample pixel values in the input image at grid points. - :param np.array im: 2D image - :param int grid_spacing: spacing of the grid - :return int row_ids: row indices of the grids - :return int col_ids: column indices of the grids - :return np.array sample_values: sampled pixel values - """ + Any incomplete grids (remainders of modulus operation) will be ignored. + + Parameters + ---------- + im : np.array + 2D image to sample from. + grid_spacing : int + Spacing of the grid points. + Returns + ------- + row_ids : np.array + Row indices of the grid points. + col_ids : np.array + Column indices of the grid points. + sample_values : np.array + Sampled pixel values at grid points. + """ im_shape = im.shape assert grid_spacing < im_shape[0], "grid spacing larger than image height" assert grid_spacing < im_shape[1], "grid spacing larger than image width" @@ -69,22 +120,38 @@ def grid_sample_pixel_values(im, grid_spacing): def preprocess_image( - im, - hist_clip_limits=None, - is_mask=False, - normalize_im=None, - zscore_mean=None, - zscore_std=None, -): - """ - Do histogram clipping, z score normalization, and potentially binarization. - - :param np.array im: Image (stack) - :param tuple hist_clip_limits: Percentile histogram clipping limits - :param bool is_mask: True if mask - :param str/None normalize_im: Normalization, if any - :param float/None zscore_mean: Data mean - :param float/None zscore_std: Data std + im: ArrayLike, + hist_clip_limits: tuple[float, float] | None = None, + is_mask: bool = False, + normalize_im: str | None = None, + zscore_mean: float | None = None, + zscore_std: float | None = None, +) -> NDArray[Any]: + """Preprocess image with histogram clipping, z-score normalization, and binarization. + + Performs histogram clipping, z-score normalization, and potentially binarization + depending on the input parameters. + + Parameters + ---------- + im : np.array + Input image or image stack. 
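A sketch exercising the two helpers documented above on a synthetic 16-bit frame:

import numpy as np
from viscy.utils.image_utils import grid_sample_pixel_values, im_adjust

im = np.random.randint(0, 4096, size=(512, 512)).astype(np.uint16)
im8 = im_adjust(im, tol=1, bit=8)  # contrast-stretched 8-bit copy
rows, cols, values = grid_sample_pixel_values(im, grid_spacing=32)
print(values.shape)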
+ hist_clip_limits : tuple, optional + Percentile histogram clipping limits (min_percentile, max_percentile), + by default None. + is_mask : bool, optional + True if input is a mask (will be binarized), by default False. + normalize_im : str, optional + Normalization method to apply, by default None. + zscore_mean : float, optional + Precomputed mean for z-score normalization, by default None. + zscore_std : float, optional + Precomputed standard deviation for z-score normalization, by default None. + + Returns + ------- + np.array + Preprocessed image. """ # remove singular dimension for 3D images if len(im.shape) > 3: diff --git a/viscy/utils/log_images.py b/viscy/utils/log_images.py index 3949f93fb..77bee649d 100644 --- a/viscy/utils/log_images.py +++ b/viscy/utils/log_images.py @@ -1,6 +1,6 @@ """Logging example images during training.""" -from typing import Sequence +from collections.abc import Sequence import numpy as np from matplotlib.pyplot import get_cmap diff --git a/viscy/utils/logging.py b/viscy/utils/logging.py index 5bdeac90b..3aa046863 100644 --- a/viscy/utils/logging.py +++ b/viscy/utils/logging.py @@ -1,6 +1,7 @@ import datetime import os import time +from typing import Any import torch @@ -8,11 +9,14 @@ from viscy.utils.normalize import hist_clipping -def log_feature(feature_map, name, log_save_folder, debug_mode): +def log_feature( + feature_map: torch.Tensor, name: str, log_save_folder: str, debug_mode: bool +) -> None: """ - If self.debug_mode, creates a visual of the given feature map, and saves it at - 'log_save_folder' - If no log_save_folder specified, saves relative to working directory with timestamp. + Create visual feature map logs for debugging deep learning models. + + If debug_mode is enabled, creates a visual of the given feature map and saves it at + 'log_save_folder'. If no log_save_folder specified, saves relative to working directory with timestamp. Currently only saving in working directory is supported. This is meant to be an analysis tool, @@ -46,15 +50,56 @@ def log_feature(feature_map, name, log_save_folder, debug_mode): class FeatureLogger: + """ + Logger for visualizing neural network feature maps during training and debugging. + + This utility class provides comprehensive feature map visualization capabilities + for monitoring convolutional neural network activations. It supports both + individual channel visualization and grid-based multi-channel displays, + with flexible normalization and spatial dimension handling. + + The logger is designed for debugging deep learning models by capturing + intermediate layer activations and saving them as organized image files. + It handles multi-dimensional tensors commonly found in computer vision + tasks, including 2D/3D spatial dimensions with batch and channel axes. + + Attributes + ---------- + save_folder : str + Directory path for saving visualization outputs + spatial_dims : int + Number of spatial dimensions in feature tensors (2D or 3D) + full_batch : bool + Whether to log all samples in batch or just the first + save_as_grid : bool + Whether to arrange channels in a grid layout + grid_width : int + Number of columns in grid visualization + normalize_by_grid : bool + Whether to normalize intensities across entire grid + + Examples + -------- + >>> logger = FeatureLogger( + ... save_folder="./feature_logs", + ... spatial_dims=3, + ... save_as_grid=True, + ... grid_width=8, + ... ) + >>> logger.log_feature_map( + ... conv_features, "conv1_activations", dim_names=["batch", "channels"] + ... 
) + """ + def __init__( self, - save_folder, - spatial_dims=3, - full_batch=False, - save_as_grid=True, - grid_width=0, - normalize_by_grid=False, - ): + save_folder: str, + spatial_dims: int = 3, + full_batch: bool = False, + save_as_grid: bool = True, + grid_width: int = 0, + normalize_by_grid: bool = False, + ) -> None: """ Logger object for handling logging feature maps inside network architectures. @@ -86,13 +131,14 @@ def __init__( def log_feature_map( self, - feature_map, - feature_name, - dim_names=[], - vmax=0, - ): + feature_map: torch.Tensor, + feature_name: str, + dim_names: list[str] | None = None, + vmax: float = 0, + ) -> None: """ - Creates a log of figures the given feature map tensor at 'save_folder'. + Create a log of figures for the given feature map tensor at 'save_folder'. + Log is saved as images of feature maps in nested directory tree. By default _assumes that batch dimension is the first dimension_, and @@ -112,7 +158,7 @@ def log_feature_map( # handle dim names num_dims = len(feature_map.shape) - if len(dim_names) == 0: + if dim_names is None: dim_names = ["dim_" + str(i) for i in range(len(num_dims))] else: assert len(dim_names) + self.spatial_dims == num_dims, ( @@ -132,11 +178,11 @@ def log_feature_map( def map_feature_dims( self, - feature_map, - save_as_grid, - vmax=0, - depth=0, - ): + feature_map: torch.Tensor, + save_as_grid: bool, + vmax: float = 0, + depth: int = 0, + ) -> None: """ Recursive directory creation for organizing feature map logs @@ -149,7 +195,6 @@ def map_feature_dims( :param float vmax: maximum intensity to normalize figures by :param int depth: recursion counter. depth in dimensions """ - for i in range(feature_map.shape[0]): if len(feature_map.shape) == 3: # individual saving @@ -257,11 +302,18 @@ def map_feature_dims( break return - def interleave_bars(self, arrays, axis, pixel_width=3, value=0): + def interleave_bars( + self, + arrays: list[torch.Tensor], + axis: int, + pixel_width: int = 3, + value: float = 0, + ) -> list[torch.Tensor]: """ + Interleave separator bars between tensors to improve grid visualization. + Takes list of 2d torch tensors and interleaves bars to improve - grid visualization quality. - Assumes arrays are all of the same shape. + grid visualization quality. Assumes arrays are all of the same shape. :param list grid_arrays: list of tensors to place bars between :param int axis: axis on which to interleave bars (0 or 1) diff --git a/viscy/utils/masks.py b/viscy/utils/masks.py index a0881fa02..7366fb53e 100644 --- a/viscy/utils/masks.py +++ b/viscy/utils/masks.py @@ -1,5 +1,8 @@ +from typing import Any + import numpy as np import scipy.ndimage as ndimage +from numpy.typing import NDArray from scipy.ndimage import binary_fill_holes from skimage.filters import gaussian, laplace, threshold_otsu from skimage.morphology import ( @@ -11,14 +14,24 @@ ) -def create_otsu_mask(input_image, sigma=0.6): - """Create a binary mask using morphological operations - :param np.array input_image: generate masks from this 3D image - :param float sigma: Gaussian blur standard deviation, - increase in value increases blur - :return: volume mask of input_image, 3D np.array - """ +def create_otsu_mask( + input_image: NDArray[Any], sigma: float = 0.6 +) -> NDArray[np.bool_]: + """Create a binary mask using Otsu thresholding and morphological operations. + Parameters + ---------- + input_image : np.array + Generate masks from this 3D image. 
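A minimal sketch for the Otsu mask helper documented here, on a synthetic ZYX volume:

import numpy as np
from viscy.utils.masks import create_otsu_mask

volume = np.random.rand(5, 256, 256)  # ZYX stack
mask = create_otsu_mask(volume, sigma=0.6)
print(mask.shape, mask.dtype)  # boolean volume mask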
+ sigma : float, optional + Gaussian blur standard deviation, increase in value increases blur, + by default 0.6. + + Returns + ------- + np.array + Volume mask of input_image, 3D binary array. + """ input_sz = input_image.shape mid_slice_id = input_sz[0] // 2 @@ -28,20 +41,36 @@ def create_otsu_mask(input_image, sigma=0.6): return mask -def create_membrane_mask(input_image, str_elem_size=23, sigma=0.4, k_size=3, msize=120): - """Create a binary mask using Laplacian of Gaussian (LOG) feature detection - - :param np.array input_image: generate masks from this image - :param int str_elem_size: size of the laplacian filter - used for contarst enhancement, odd number. - Increase in value increases sensitivity of contrast enhancement - :param float sigma: Gaussian blur standard deviation - :param int k_size: disk/ball size for mask dilation, - ball for 3D and disk for 2D data - :param int msize: size of small objects removed to clean segmentation - :return: mask of input_image, np.array +def create_membrane_mask( + input_image: NDArray[Any], + str_elem_size: int = 23, + sigma: float = 0.4, + k_size: int = 3, + msize: int = 120, +) -> NDArray[np.bool_]: + """Create a binary mask using Laplacian of Gaussian (LOG) feature detection. + + Parameters + ---------- + input_image : np.array + Generate masks from this image. + str_elem_size : int, optional + Size of the laplacian filter used for contrast enhancement, odd number. + Increase in value increases sensitivity of contrast enhancement, + by default 23. + sigma : float, optional + Gaussian blur standard deviation, by default 0.4. + k_size : int, optional + Disk/ball size for mask dilation, ball for 3D and disk for 2D data, + by default 3. + msize : int, optional + Size of small objects removed to clean segmentation, by default 120. + + Returns + ------- + np.array + Binary mask of input_image. """ - input_image_blur = gaussian(input_image, sigma=sigma) input_Lapl = laplace(input_image_blur, ksize=str_elem_size) @@ -61,17 +90,24 @@ def create_membrane_mask(input_image, str_elem_size=23, sigma=0.4, k_size=3, msi return mask -def get_unimodal_threshold(input_image): - """Determines optimal unimodal threshold +def get_unimodal_threshold(input_image: NDArray[Any]) -> float: + """Determine optimal unimodal threshold using Rosin's method. + References + ---------- https://users.cs.cf.ac.uk/Paul.Rosin/resources/papers/unimodal2.pdf https://www.mathworks.com/matlabcentral/fileexchange/45443-rosin-thresholding - :param np.array input_image: generate mask for this image - :return float best_threshold: optimal lower threshold for the foreground - hist - """ + Parameters + ---------- + input_image : np.array + Generate mask for this image. + Returns + ------- + float + Optimal lower threshold for the foreground histogram. + """ hist_counts, bin_edges = np.histogram( input_image, bins=256, @@ -105,17 +141,27 @@ def get_unimodal_threshold(input_image): return best_threshold -def create_unimodal_mask(input_image, str_elem_size=3, sigma=0.6): - """ - Create a mask with unimodal thresholding and morphological operations. - Unimodal thresholding seems to oversegment, erode it by a fraction +def create_unimodal_mask( + input_image: NDArray[Any], str_elem_size: int = 3, sigma: float = 0.6 +) -> NDArray[np.bool_]: + """Create a mask with unimodal thresholding and morphological operations. - :param np.array input_image: generate masks from this image - :param int str_elem_size: size of the structuring element. 
typically 3, 5
-    :param float sigma: gaussian blur standard deviation
-    :return mask of input_image, np.array
-    """
+    Unimodal thresholding seems to oversegment, erode it by a fraction.
 
+    Parameters
+    ----------
+    input_image : np.array
+        Generate masks from this image.
+    str_elem_size : int, optional
+        Size of the structuring element, typically 3 or 5, by default 3.
+    sigma : float, optional
+        Gaussian blur standard deviation, by default 0.6.
+
+    Returns
+    -------
+    np.array
+        Binary mask of input_image.
+    """
     input_image = gaussian(input_image, sigma=sigma)
 
     if np.min(input_image) == np.max(input_image):
@@ -133,21 +179,31 @@ def create_unimodal_mask(input_image, str_elem_size=3, sigma=0.6):
     return mask
 
 
-def get_unet_border_weight_map(annotation, w0=10, sigma=5):
-    """
-    Return weight map for borders as specified in UNet paper
-    :param annotation A 2D array of shape (image_height, image_width)
-    contains annotation with each class labeled as an integer.
-    :param w0 multiplier to the exponential distance loss
-    default 10 as mentioned in UNet paper
-    :param sigma standard deviation in the exponential distance term
-    e^(-d1 + d2) ** 2 / 2 (sigma ^ 2)
-    default 5 as mentioned in UNet paper
-    :return weight mapt for borders as specified in UNet
-
-    TODO: Calculate boundaries directly and calculate distance
-    from boundary of cells to another
-    Note: The below method only works for UNet Segmentation only
+def get_unet_border_weight_map(
+    annotation: NDArray[Any], w0: int = 10, sigma: int = 5
+) -> NDArray[np.float64]:
+    """Return weight map for borders as specified in the U-Net paper.
+
+    TODO: Calculate boundaries directly and calculate distance from boundary
+    of cells to another. Note: this method only works for U-Net segmentation.
+
+    Parameters
+    ----------
+    annotation : np.array
+        A 2D array of shape (image_height, image_width) containing annotation
+        with each class labeled as an integer.
+    w0 : int, optional
+        Multiplier to the exponential distance loss, by default 10
+        (as in the U-Net paper).
+    sigma : int, optional
+        Standard deviation in the exponential distance term
+        w0 * exp(-((d1 + d2) ** 2) / (2 * sigma ** 2)), by default 5
+        (as in the U-Net paper).
+
+    Returns
+    -------
+    np.array
+        Weight map for borders as specified in the U-Net paper.
     """
-    # if there is only one label, zero return the array as is
+    # if there are no labels, return the array as is
     if np.sum(annotation) == 0:
         return annotation
@@ -160,7 +216,7 @@ def get_unet_border_weight_map(annotation, w0=10, sigma=5):
     assert annotation.dtype in [
         np.uint8,
         np.uint16,
-    ], "Expected data type uint, it is {}".format(annotation.dtype)
+    ], f"Expected data type uint, it is {annotation.dtype}"
 
     # cells instances for distance computation
     # 4 connected i.e default (cross-shaped)
diff --git a/viscy/utils/meta_utils.py b/viscy/utils/meta_utils.py
index 961b66967..22a28754e 100644
--- a/viscy/utils/meta_utils.py
+++ b/viscy/utils/meta_utils.py
@@ -10,8 +10,9 @@
 
 
 def write_meta_field(position: ngff.Position, metadata, field_name, subfield_name):
-    """
-    Writes 'metadata' to position's plate-level or FOV level .zattrs metadata by either
+    """Write 'metadata' to position's plate-level or FOV level .zattrs metadata.
+
+    Write 'metadata' to position's plate-level or FOV level .zattrs metadata by either
     creating a new field (field_name) according to 'metadata', or updating the
     metadata to an existing field if found, or concatenating the metadata from
     different channels.
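To sanity-check the converted mask utilities above, here is a minimal usage
sketch (the random volume and the uint8 cast are illustrative assumptions, not
part of this patch):

    import numpy as np

    from viscy.utils.masks import create_otsu_mask, get_unet_border_weight_map

    # Any single-channel 3D stack (z, y, x) works as input here.
    volume = np.random.rand(5, 256, 256).astype(np.float32)

    # Otsu threshold derived from the blurred middle slice, applied as a volume mask.
    mask = create_otsu_mask(volume, sigma=0.6)

    # The border weight map expects an integer-labeled 2D annotation (uint8/uint16).
    labels = mask[2].astype(np.uint8)
    weights = get_unet_border_weight_map(labels, w0=10, sigma=5)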
@@ -51,7 +52,8 @@ def generate_normalization_metadata(
     channel_ids=-1,
     grid_spacing=32,
 ):
-    """
+    """Generate pixel intensity metadata for on-the-fly normalization.
+
     Generate pixel intensity metadata to be later used in on-the-fly normalization
     during training and inference. Sampling is used for efficient estimation of median
     and interquartile range for intensity values on both a dataset and field-of-view
@@ -78,147 +80,143 @@
     plate = ngff.open_ome_zarr(zarr_dir, mode="r+")
     position_map = list(plate.positions())
 
+    # Get channels to process
     if channel_ids == -1:
-        channel_ids = range(len(plate.channel_names))
-    elif isinstance(channel_ids, int):
+        # Get channel IDs from first position
+        first_position = position_map[0][1]
+        first_images = list(first_position.images())
+        first_image = first_images[0][1]
+        # shape is (t, c, z, y, x)
+        channel_ids = list(range(first_image.data.shape[1]))
+
+    if isinstance(channel_ids, int):
         channel_ids = [channel_ids]
 
-    # get arguments for multiprocessed grid sampling
-    mp_grid_sampler_args = []
-    for _, position in position_map:
-        mp_grid_sampler_args.append([position, grid_spacing])
-
-    # sample values and use them to get normalization statistics
-    for i, channel in enumerate(channel_ids):
-        show_progress_bar(
-            dataloader=channel_ids,
-            current=i,
-            process="sampling channel values",
-        )
-
-        channel_name = plate.channel_names[channel]
-        this_channels_args = tuple([args + [channel] for args in mp_grid_sampler_args])
-
-        # NOTE: Doing sequential mp with pool execution creates synchronization
-        # points between each step. This could be detrimental to performance
-        positions, fov_sample_values = mp_utils.mp_sample_im_pixels(
-            this_channels_args, num_workers
-        )
-        dataset_sample_values = np.concatenate(
-            [arr.flatten() for arr in fov_sample_values]
-        )
-        fov_level_statistics = mp_utils.mp_get_val_stats(fov_sample_values, num_workers)
-        dataset_level_statistics = mp_utils.get_val_stats(dataset_sample_values)
-
-        dataset_statistics = {
-            "dataset_statistics": dataset_level_statistics,
-        }
-
-        write_meta_field(
-            position=plate,
-            metadata=dataset_statistics,
-            field_name="normalization",
-            subfield_name=channel_name,
-        )
-
-        for j, pos in enumerate(positions):
-            show_progress_bar(
-                dataloader=position_map,
-                current=j,
-                process=f"calculating channel statistics {channel}/{list(channel_ids)}",
-            )
-            position_statistics = dataset_statistics | {
-                "fov_statistics": fov_level_statistics[j],
+    # Prepare parameters for each position and channel
+    params_list = []
+    for position_key, _ in position_map:
+        for channel_id in channel_ids:
+            params = {
+                "zarr_dir": zarr_dir,
+                "position_key": position_key,
+                "channel_id": channel_id,
+                "grid_spacing": grid_spacing,
+            }
+            params_list.append(params)
+
+    # Compute normalization statistics; each worker returns
+    # (position_key, channel_id, dataset_stats, fov_stats) or None
+    if num_workers > 1:
+        from multiprocessing import get_context
+
+        with get_context("spawn").Pool(num_workers) as pool:
+            results = pool.map(mp_utils.normalize_meta_worker, params_list)
+    else:
+        results = []
+        for i, params in enumerate(params_list):
+            show_progress_bar(
+                dataloader=params_list,
+                current=i,
+                process="sampling channel values",
+            )
+            results.append(mp_utils.normalize_meta_worker(params))
+
+    # Aggregate results and write to metadata
+    all_dataset_stats = {}
+    for result in results:
+        if result is not None:
+            position_key, channel_id, dataset_stats, fov_stats = result
+
+            if channel_id not in all_dataset_stats:
+                all_dataset_stats[channel_id] = []
+            all_dataset_stats[channel_id].append(dataset_stats)
+
+    # Calculate dataset-level statistics
+    final_dataset_stats = {}
+    for channel_id, stats_list in all_dataset_stats.items():
+        if stats_list:
+            # Aggregate median and IQR across all positions
+            medians = [stats["median"] for stats in stats_list if "median" in stats]
+            iqrs = [stats["iqr"] for stats in stats_list if "iqr" in stats]
+
+            if medians and iqrs:
+                final_dataset_stats[channel_id] = {
+                    "median": np.median(medians),
+                    "iqr": np.median(iqrs),
+                }
+
+    # Write metadata to each position
+    positions_by_key = dict(plate.positions())
+    for result in results:
+        if result is not None:
+            position_key, channel_id, dataset_stats, fov_stats = result
+
+            # Get position object
+            position = positions_by_key[position_key]
+
+            # Prepare metadata
+            metadata = {
+                "dataset_statistics": final_dataset_stats.get(channel_id, {}),
+                "fov_statistics": fov_stats,
             }
 
+            # Write metadata, keyed by channel name as elsewhere in this module
             write_meta_field(
-                position=pos,
-                metadata=position_statistics,
+                position=position,
+                metadata=metadata,
                 field_name="normalization",
-                subfield_name=channel_name,
+                subfield_name=plate.channel_names[channel_id],
             )
 
     plate.close()
 
+    print(f"Generated normalization metadata for {len(channel_ids)} channels")
+    print(f"Dataset-level statistics: {final_dataset_stats}")
+
 
-def compute_zscore_params(
-    frames_meta, ints_meta, input_dir, normalize_im, min_fraction=0.99
-):
-    """
-    Get zscore median and interquartile range
-
-    :param pd.DataFrame frames_meta: Dataframe containing all metadata
-    :param pd.DataFrame ints_meta: Metadata containing intensity statistics
-    each z-slice and foreground fraction for masks
-    :param str input_dir: Directory containing images
-    :param None or str normalize_im: normalization scheme for input images
-    :param float min_fraction: Minimum foreground fraction (in case of masks)
-    for computing intensity statistics.
-
-    :return pd.DataFrame frames_meta: Dataframe containing all metadata
-    :return pd.DataFrame ints_meta: Metadata containing intensity statistics
-    each z-slice
+
+def compute_normalization_stats(image_data, grid_spacing=32):
+    """Compute normalization statistics from image data using grid sampling.
+
+    Parameters
+    ----------
+    image_data : np.array
+        3D or 4D image array (z, y, x) or (t, z, y, x).
+    grid_spacing : int, optional
+        Spacing between grid points for sampling, by default 32.
+
+    Returns
+    -------
+    dict
+        Dictionary with median and IQR statistics.
     """
+    # Handle different input shapes
+    if image_data.ndim == 4:
+        # Assume (t, z, y, x) and take first timepoint
+        image_data = image_data[0]
+
+    if image_data.ndim == 3:
+        # Assume (z, y, x) and use middle z-slice if available
+        if image_data.shape[0] > 1:
+            z_mid = image_data.shape[0] // 2
+            image_data = image_data[z_mid]
+        else:
+            image_data = image_data[0]
 
-    assert normalize_im in [
-        None,
-        "slice",
-        "volume",
-        "dataset",
-    ], 'normalize_im must be None or "slice" or "volume" or "dataset"'
-
-    if normalize_im is None:
-        # No normalization
-        frames_meta["zscore_median"] = 0
-        frames_meta["zscore_iqr"] = 1
-        return frames_meta
-    elif normalize_im == "dataset":
-        agg_cols = ["time_idx", "channel_idx", "dir_name"]
-    elif normalize_im == "volume":
-        agg_cols = ["time_idx", "channel_idx", "dir_name", "pos_idx"]
-    else:
-        agg_cols = ["time_idx", "channel_idx", "dir_name", "pos_idx", "slice_idx"]
-    # median and inter-quartile range are more robust than mean and std
-    ints_meta_sub = ints_meta[ints_meta["fg_frac"] >= min_fraction]
-    ints_agg_median = ints_meta_sub[agg_cols + ["intensity"]].groupby(agg_cols).median()
-    ints_agg_hq = (
-        ints_meta_sub[agg_cols + ["intensity"]].groupby(agg_cols).quantile(0.75)
-    )
-    ints_agg_lq = (
-        ints_meta_sub[agg_cols + ["intensity"]].groupby(agg_cols).quantile(0.25)
-    )
-    ints_agg = ints_agg_median
-    ints_agg.columns = ["zscore_median"]
-    ints_agg["zscore_iqr"] = ints_agg_hq["intensity"] - ints_agg_lq["intensity"]
-    ints_agg.reset_index(inplace=True)
-
-    cols_to_merge = frames_meta.columns[
-        [col not in ["zscore_median", "zscore_iqr"] for col in frames_meta.columns]
-    ]
-    frames_meta = pd.merge(
-        frames_meta[cols_to_merge],
-        ints_agg,
-        how="left",
-        on=agg_cols,
-    )
-    if frames_meta["zscore_median"].isnull().values.any():
-        raise ValueError(
-            "Found NaN in normalization parameters. \
-            min_fraction might be too low or images might be corrupted."
- )
-    frames_meta_filename = os.path.join(input_dir, "frames_meta.csv")
-    frames_meta.to_csv(frames_meta_filename, sep=",")
-
-    cols_to_merge = ints_meta.columns[
-        [col not in ["zscore_median", "zscore_iqr"] for col in ints_meta.columns]
-    ]
-    ints_meta = pd.merge(
-        ints_meta[cols_to_merge],
-        ints_agg,
-        how="left",
-        on=agg_cols,
-    )
-    ints_meta["intensity_norm"] = (
-        ints_meta["intensity"] - ints_meta["zscore_median"]
-    ) / (ints_meta["zscore_iqr"] + sys.float_info.epsilon)
-
-    return frames_meta, ints_meta
+    # Now image_data should be 2D (y, x)
+    if image_data.ndim != 2:
+        raise ValueError(f"Expected 2D image after processing, got {image_data.ndim}D")
+
+    # Create sampling grid
+    y_indices = np.arange(0, image_data.shape[0], grid_spacing)
+    x_indices = np.arange(0, image_data.shape[1], grid_spacing)
+
+    # Sample values at grid points
+    sampled_values = image_data[np.ix_(y_indices, x_indices)].flatten()
+
+    # Remove any NaN or infinite values
+    sampled_values = sampled_values[np.isfinite(sampled_values)]
+
+    if len(sampled_values) == 0:
+        return {"median": 0.0, "iqr": 1.0}
+
+    # Compute statistics
+    median = np.median(sampled_values)
+    q25 = np.percentile(sampled_values, 25)
+    q75 = np.percentile(sampled_values, 75)
+    iqr = q75 - q25
+
+    # Avoid zero IQR
+    if iqr == 0:
+        iqr = 1.0
+
+    return {"median": float(median), "iqr": float(iqr)}
diff --git a/viscy/utils/mp_utils.py b/viscy/utils/mp_utils.py
index ee46f7ea9..ce04e8093 100644
--- a/viscy/utils/mp_utils.py
+++ b/viscy/utils/mp_utils.py
@@ -1,19 +1,34 @@
+from collections.abc import Callable
 from concurrent.futures import ProcessPoolExecutor
+from typing import Any
 
 import iohub.ngff as ngff
 import numpy as np
 import scipy.stats
+import zarr
 
 import viscy.utils.image_utils as image_utils
 import viscy.utils.masks as mask_utils
 
 
-def mp_wrapper(fn, fn_args, workers):
-    """Create and save masks with multiprocessing
-
-    :param list of tuple fn_args: list with tuples of function arguments
-    :param int workers: max number of workers
-    :return: list of returned dicts from create_save_mask
+def mp_wrapper(
+    fn: Callable[..., Any], fn_args: list[tuple[Any, ...]], workers: int
+) -> list[Any]:
+    """Apply a function to a list of argument tuples with multiprocessing.
+
+    Parameters
+    ----------
+    fn : callable
+        Function to be applied with multiprocessing.
+    fn_args : list of tuple
+        List with tuples of function arguments.
+    workers : int
+        Max number of workers.
+
+    Returns
+    -------
+    list
+        List of values returned by fn for each argument tuple.
     """
     with ProcessPoolExecutor(workers) as ex:
         # can't use map directly as it works only with single arg functions
         res = ex.map(fn, *zip(*fn_args))
     return list(res)
 
 
-def mp_create_and_write_mask(fn_args, workers):
-    """Create and save masks with multiprocessing. For argument parameters
-    see mp_utils.create_and_write_mask.
-
-    :param list of tuple fn_args: list with tuples of function arguments
-    :param int workers: max number of workers
-    :return: list of returned dicts from create_save_mask
+def mp_create_and_write_mask(fn_args: list[tuple[Any, ...]], workers: int) -> list[Any]:
+    """Create and save masks with multiprocessing.
+
+    For argument parameters see mp_utils.create_and_write_mask.
+
+    Parameters
+    ----------
+    fn_args : list of tuple
+        List with tuples of function arguments.
+    workers : int
+        Max number of workers.
+
+    Returns
+    -------
+    list
+        List of results from create_and_write_mask for each argument tuple.
""" with ProcessPoolExecutor(workers) as ex: # can't use map directly as it works only with single arg functions @@ -37,29 +61,35 @@ def mp_create_and_write_mask(fn_args, workers): def add_channel( position: ngff.Position, - new_channel_array, - new_channel_name, - overwrite_ok=False, -): - """ - Adds a channels to the data array at position "position". Note that there is - only one 'tracked' data array in current HCS spec at each position. Also - updates the 'omero' channel-tracking metadata to track the new channel. + new_channel_array: np.ndarray, + new_channel_name: str, + overwrite_ok: bool = False, +) -> None: + """Add a channel to the data array at specified position. + + Note that there is only one 'tracked' data array in current HCS spec at each position. + Also updates the 'omero' channel-tracking metadata to track the new channel. The 'new_channel_array' must match the dimensions of the current array in - all positions but the channel position (1) and have the same datatype + all positions but the channel position (1) and have the same datatype. Note: to maintain HCS compatibility of the zarr store, all positions (wells) must maintain arrays with congruent channels. That is, if you add a channel to one position of an HCS compatible zarr store, an additional channel must be added to every position in that store to maintain HCS compatibility. - :param Position zarr_dir: NGFF position node object - :param np.ndarray new_channel_array: array to add as new channel with matching - dimensions (except channel dim) and dtype - :param str new_channel_name: name of new channel - :param bool overwrite_ok: if true, if a channel with the same name as - 'new_channel_name' is found, will overwrite + Parameters + ---------- + position : ngff.Position + NGFF position node object. + new_channel_array : np.ndarray + Array to add as new channel with matching dimensions (except channel dim) + and dtype. + new_channel_name : str + Name of new channel. + overwrite_ok : bool, optional + If true, if a channel with the same name as 'new_channel_name' is found, + will overwrite, by default False. """ assert len(new_channel_array.shape) == len(position.data.shape) - 1, ( "New channel array must match all dimensions of the position array, " @@ -82,20 +112,18 @@ def add_channel( def create_and_write_mask( position: ngff.Position, - time_indices, - channel_indices, - structure_elem_radius, - mask_type, - mask_name, - verbose=False, -): - # TODO: rewrite docstring - """ - Create mask *for all depth slices* at each time and channel index specified - in this position, and save them both as an additional channel in the data array - of the given zarr store and a separate 'untracked' array with specified name. - If output_channel_index is specified as an existing channel index, will overwrite - this channel instead. + time_indices: list[int], + channel_indices: list[int], + structure_elem_radius: int, + mask_type: str, + mask_name: str, + verbose: bool = False, +) -> None: + """Create mask for all depth slices at specified time and channel indices. + + Creates masks at each time and channel index specified in this position, + and saves them both as an additional channel in the data array of the given + zarr store and a separate 'untracked' array with specified name. Saves custom metadata related to the mask creation in the well-level .zattrs in the 'mask' field. @@ -105,24 +133,25 @@ def create_and_write_mask( a timepoint-position basis. 
That is, it will be recorded as an average foreground fraction over all slices in any given timepoint. - - :param str zarr_dir: directory to HCS compatible zarr store for usage - :param str position_path: path within store to position to generate masks for - :param list time_indices: list of time indices for mask generation, - if an index is skipped over, will populate with - zeros - :param list channel_indices: list of channel indices for mask generation, - if more than 1 channel specified, masks from all - channels are aggregated - :param int structure_elem_radius: size of structuring element used for binary - opening. str_elem: disk or ball - :param str mask_type: thresholding type used for masking or str to map to - masking function - :param str mask_name: name under which to save untracked copy of mask in - position - :param bool verbose: whether this process should send updates to stdout + Parameters + ---------- + position : ngff.Position + NGFF position node object. + time_indices : list + List of time indices for mask generation. If an index is skipped over, + will populate with zeros. + channel_indices : list + List of channel indices for mask generation. If more than 1 channel + specified, masks from all channels are aggregated. + structure_elem_radius : int + Size of structuring element used for binary opening. str_elem: disk or ball. + mask_type : str + Thresholding type used for masking or str to map to masking function. + mask_name : str + Name under which to save untracked copy of mask in position. + verbose : bool, optional + Whether this process should send updates to stdout, by default False. """ - shape = position.data.shape position_masks_shape = tuple([shape[0], len(channel_indices), *shape[2:]]) @@ -195,25 +224,35 @@ def create_and_write_mask( def get_mask_slice( - position_zarr, - time_index, - channel_index, - mask_type, - structure_elem_radius, -): - """ + position_zarr: zarr.Array, + time_index: int, + channel_index: int, + mask_type: str, + structure_elem_radius: int, +) -> np.ndarray: + """Compute mask for a single image slice. + Given a set of indices, mask type, and structuring element, pulls an image slice from the given zarr array, computes the requested mask and returns. - :param zarr.Array position_zarr: zarr array of the desired position - :param time_index: see name - :param channel_index: see name - :param mask_type: see name, - options are {otsu, unimodal, mem_detection, borders_weight_loss_map} - :param int structure_elem_radius: creation radius for the structuring - element - :return np.ndarray mask: 2d mask for this slice + Parameters + ---------- + position_zarr : zarr.Array + Zarr array of the desired position. + time_index : int + Time index for the slice. + channel_index : int + Channel index for the slice. + mask_type : str + Mask type, options are {otsu, unimodal, mem_detection, borders_weight_loss_map}. + structure_elem_radius : int + Creation radius for the structuring element. + + Returns + ------- + np.ndarray + 2D mask for this slice. 
""" # read and correct/preprocess slice im = position_zarr[time_index, channel_index] @@ -237,9 +276,9 @@ def get_mask_slice( return mask -def mp_get_val_stats(fn_args, workers): +def mp_get_val_stats(fn_args: list[Any], workers: int) -> list[dict[str, float]]: """ - Computes statistics of numpy arrays with multiprocessing + Compute statistics of numpy arrays with multiprocessing :param list of tuple fn_args: list with tuples of function arguments :param int workers: max number of workers @@ -251,16 +290,22 @@ def mp_get_val_stats(fn_args, workers): return list(res) -def get_val_stats(sample_values): - """ +def get_val_stats(sample_values: list[float]) -> dict[str, float]: + """Compute statistics of a numpy array. + Computes the statistics of a numpy array and returns a dictionary of metadata corresponding to input sample values. - :param list(float) sample_values: List of sample values at respective - indices - :return dict meta_row: Dict with intensity data for image - """ + Parameters + ---------- + sample_values : list of float + List of sample values at respective indices. + Returns + ------- + dict + Dictionary with intensity data for image. + """ meta_row = { "mean": float(np.nanmean(sample_values)), "std": float(np.nanstd(sample_values)), @@ -270,14 +315,15 @@ def get_val_stats(sample_values): return meta_row -def mp_sample_im_pixels(fn_args, workers): +def mp_sample_im_pixels( + fn_args: list[tuple[Any, ...]], workers: int +) -> list[list[Any]]: """Read and computes statistics of images with multiprocessing :param list of tuple fn_args: list with tuples of function arguments :param int workers: max number of workers :return: list of paths and corresponding returned df from get_im_stats """ - with ProcessPoolExecutor(workers) as ex: # can't use map directly as it works only with single arg functions res = ex.map(sample_im_pixels, *zip(*fn_args)) @@ -286,23 +332,30 @@ def mp_sample_im_pixels(fn_args, workers): def sample_im_pixels( position: ngff.Position, - grid_spacing, - channel, -): + grid_spacing: int, + channel: int, +) -> tuple[ngff.Position, np.ndarray]: # TODO move out of mp utils into normalization utils - """ - Read and computes statistics of images for each point in a grid. + """Read and compute statistics of images for each point in a grid. + Grid spacing determines distance in pixels between grid points - for rows and cols. - By default, samples from every time position and every z-depth, and - assumes that the data in the zarr store is stored in [T,C,Z,Y,X] format, + for rows and cols. By default, samples from every time position and every z-depth, + and assumes that the data in the zarr store is stored in [T,C,Z,Y,X] format, for time, channel, z, y, x. - :param Position zarr_dir: NGFF position node object - :param int grid_spacing: spacing of sampling grid in x and y - :param int channel: channel to sample from - - :return list meta_rows: Dicts with intensity data for each grid point + Parameters + ---------- + position : ngff.Position + NGFF position node object. + grid_spacing : int + Spacing of sampling grid in x and y. + channel : int + Channel to sample from. + + Returns + ------- + list + Dicts with intensity data for each grid point. 
""" image_zarr = position.data diff --git a/viscy/utils/normalize.py b/viscy/utils/normalize.py index 73753acb7..6d12baba3 100644 --- a/viscy/utils/normalize.py +++ b/viscy/utils/normalize.py @@ -1,19 +1,33 @@ -"""Image normalization related functions""" +"""Image normalization related functions.""" import sys +from typing import Any import numpy as np +from numpy.typing import ArrayLike, NDArray from skimage.exposure import equalize_adapthist -def zscore(input_image, im_mean=None, im_std=None): - """ - Performs z-score normalization. Adds epsilon in denominator for robustness +def zscore( + input_image: ArrayLike, im_mean: float | None = None, im_std: float | None = None +) -> NDArray[Any]: + """Perform z-score normalization. + + Adds epsilon in denominator for robustness. + + Parameters + ---------- + input_image : np.array + Input image for intensity normalization. + im_mean : float, optional + Image mean, by default None. + im_std : float, optional + Image std, by default None. - :param np.array input_image: input image for intensity normalization - :param float/None im_mean: Image mean - :param float/None im_std: Image std - :return np.array norm_img: z score normalized image + Returns + ------- + np.array + Z-score normalized image. """ if not im_mean: im_mean = np.nanmean(input_image) @@ -23,50 +37,87 @@ def zscore(input_image, im_mean=None, im_std=None): return norm_img -def unzscore(im_norm, zscore_median, zscore_iqr): - """ - Revert z-score normalization applied during preprocessing. Necessary - before computing SSIM +def unzscore( + im_norm: ArrayLike, zscore_median: float, zscore_iqr: float +) -> NDArray[Any]: + """Revert z-score normalization applied during preprocessing. + + Necessary before computing SSIM. - :param im_norm: Normalized image for un-zscore - :param zscore_median: Image median - :param zscore_iqr: Image interquartile range - :return im: image at its original scale + Parameters + ---------- + im_norm : array_like + Normalized image for un-zscore. + zscore_median : float + Image median. + zscore_iqr : float + Image interquartile range. + + Returns + ------- + array_like + Image at its original scale. """ im = im_norm * (zscore_iqr + sys.float_info.epsilon) + zscore_median return im -def hist_clipping(input_image, min_percentile=2, max_percentile=98): - """Clips and rescales histogram from min to max intensity percentiles - - rescale_intensity with input check - - :param np.array input_image: input image for intensity normalization - :param int/float min_percentile: min intensity percentile - :param int/flaot max_percentile: max intensity percentile - :return: np.float, intensity clipped and rescaled image +def hist_clipping( + input_image: ArrayLike, + min_percentile: int | float = 2, + max_percentile: int | float = 98, +) -> NDArray[Any]: + """Clip and rescale histogram from min to max intensity percentiles. + + rescale_intensity with input check. + + Parameters + ---------- + input_image : np.array + Input image for intensity normalization. + min_percentile : int or float, optional + Min intensity percentile, by default 2. + max_percentile : int or float, optional + Max intensity percentile, by default 98. + + Returns + ------- + np.array + Intensity clipped and rescaled image. 
""" - assert (min_percentile < max_percentile) and max_percentile <= 100 pmin, pmax = np.percentile(input_image, (min_percentile, max_percentile)) hist_clipped_image = np.clip(input_image, pmin, pmax) return hist_clipped_image -def hist_adapteq_2D(input_image, kernel_size=None, clip_limit=None): - """CLAHE on 2D images +def hist_adapteq_2D( + input_image: NDArray[Any], + kernel_size: int | list[int] | tuple[int, ...] | None = None, + clip_limit: float | None = None, +) -> NDArray[Any]: + """Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) on 2D images. skimage.exposure.equalize_adapthist works only for 2D. Extend to 3D or use - openCV? Not ideal, as it enhances noise in homogeneous areas - - :param np.array input_image: input image for intensity normalization - :param int/list kernel_size: Neighbourhood to be used for histogram - equalization. If none, use default of 1/8th image size. - :param float clip_limit: Clipping limit, normalized between 0 and 1 - (higher values give more contrast, ~ max percent of voxels in any - histogram bin, if > this limit, the voxel intensities are redistributed). - if None, default=0.01 + openCV? Not ideal, as it enhances noise in homogeneous areas. + + Parameters + ---------- + input_image : np.array + Input image for intensity normalization. + kernel_size : int or list, optional + Neighbourhood to be used for histogram equalization. If None, use default + of 1/8th image size, by default None. + clip_limit : float, optional + Clipping limit, normalized between 0 and 1 (higher values give more + contrast, ~ max percent of voxels in any histogram bin, if > this limit, + the voxel intensities are redistributed). If None, default=0.01, + by default None. + + Returns + ------- + np.array + Adaptive histogram equalized image. 
""" nrows, ncols = input_image.shape if kernel_size is not None: @@ -78,9 +129,7 @@ def hist_adapteq_2D(input_image, kernel_size=None, clip_limit=None): raise ValueError("kernel size invalid: not an int / list / tuple") if clip_limit is not None: - assert 0 <= clip_limit <= 1, "Clip limit {} is out of range [0, 1]".format( - clip_limit - ) + assert 0 <= clip_limit <= 1, f"Clip limit {clip_limit} is out of range [0, 1]" adapt_eq_image = equalize_adapthist( input_image, kernel_size=kernel_size, clip_limit=clip_limit From 2dc4c794b59c528446f95e712d39d0df5c9ae7ae Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Wed, 10 Sep 2025 18:24:37 -0700 Subject: [PATCH 02/13] finished second half of docstrings --- viscy/cli.py | 6 +- viscy/representation/classification.py | 30 +- viscy/representation/embedding_writer.py | 20 +- viscy/representation/engine.py | 43 ++- .../evaluation/visualization.py | 4 +- viscy/representation/multi_modal.py | 24 +- viscy/transforms/_gaussian_blur.py | 40 +++ viscy/transforms/_transforms.py | 118 ++++++- viscy/translation/engine.py | 92 +++--- viscy/translation/evaluation.py | 8 +- viscy/translation/evaluation_metrics.py | 101 +++++- viscy/translation/predict_writer.py | 2 +- viscy/unet/networks/Unet25D.py | 10 +- viscy/unet/networks/Unet2D.py | 4 +- viscy/unet/networks/fcmae.py | 294 +++++++++++++----- viscy/unet/networks/layers/ConvBlock2D.py | 79 +++-- viscy/unet/networks/layers/ConvBlock3D.py | 67 ++-- viscy/unet/networks/unext2.py | 120 ++++++- viscy/utils/aux_utils.py | 10 +- viscy/utils/cli_utils.py | 44 ++- viscy/utils/image_utils.py | 22 +- viscy/utils/log_images.py | 13 +- viscy/utils/logging.py | 161 ++++++---- viscy/utils/masks.py | 18 +- viscy/utils/meta_utils.py | 83 +++-- viscy/utils/mp_utils.py | 45 ++- viscy/utils/normalize.py | 28 +- viscy/utils/slurm_utils.py | 3 +- 28 files changed, 1051 insertions(+), 438 deletions(-) diff --git a/viscy/cli.py b/viscy/cli.py index 93f3d118b..65f4137ed 100644 --- a/viscy/cli.py +++ b/viscy/cli.py @@ -8,7 +8,7 @@ import torch from jsonargparse import lazy_instance from lightning.pytorch import LightningDataModule, LightningModule -from lightning.pytorch.cli import LightningCLI +from lightning.pytorch.cli import LightningArgumentParser, LightningCLI from lightning.pytorch.loggers import TensorBoardLogger from viscy.trainer import VisCyTrainer @@ -33,12 +33,12 @@ def subcommands() -> dict[str, set[str]]: subcommands["precompute"] = subcommand_base_args return subcommands - def add_arguments_to_parser(self, parser) -> None: + def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: """Add default arguments to the Lightning CLI parser. Parameters ---------- - parser + parser : LightningArgumentParser Lightning CLI parser instance to configure. """ parser.set_defaults( diff --git a/viscy/representation/classification.py b/viscy/representation/classification.py index 27aac722b..ae64f2926 100644 --- a/viscy/representation/classification.py +++ b/viscy/representation/classification.py @@ -20,6 +20,13 @@ class ClassificationPredictionWriter(BasePredictionWriter): """ def __init__(self, output_path: Path) -> None: + """Initialize the prediction writer. + + Parameters + ---------- + output_path : Path + Path to the output CSV file. 
+ """ super().__init__("epoch") if Path(output_path).exists(): raise FileExistsError(f"Output path {output_path} already exists.") @@ -36,13 +43,13 @@ def write_on_epoch_end( Parameters ---------- - trainer : lightning.Trainer + trainer : Trainer PyTorch Lightning trainer instance. - pl_module : lightning.LightningModule + pl_module : LightningModule Lightning module being trained. - predictions : list + predictions : list[dict[str, Any]] List of prediction dictionaries from all batches. - batch_indices : list + batch_indices : list[int] Indices of batches processed during prediction. """ all_predictions = [] @@ -68,6 +75,17 @@ def __init__( lr: float | None, loss: nn.Module | None = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(1.0)), ) -> None: + """Initialize the classification module. + + Parameters + ---------- + encoder : ContrastiveEncoder + Contrastive encoder model. + lr : float | None + Learning rate. + loss : nn.Module | None + Loss function. By default, BCEWithLogitsLoss with positive weight of 1.0. + """ super().__init__() self.stem = encoder.stem self.backbone = encoder.encoder @@ -166,7 +184,7 @@ def predict_step( batch: tuple[torch.Tensor, torch.Tensor, dict[str, Any]], batch_idx: int, dataloader_idx: int | None = None, - ) -> dict[str, Any]: + ) -> dict[str, torch.Tensor]: """Execute prediction step with sigmoid activation for probabilities. Parameters @@ -180,7 +198,7 @@ def predict_step( Returns ------- - dict + dict[str, torch.Tensor] Dictionary containing indices, labels, and sigmoid probabilities. """ x, y, indices = batch diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index d4b1e9f62..0e4db683c 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd import torch +import xarray as xr from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import BasePredictionWriter from numpy.typing import NDArray @@ -22,7 +23,7 @@ _logger = logging.getLogger("lightning.pytorch") -def read_embedding_dataset(path: Path) -> Dataset: +def read_embedding_dataset(path: Path) -> xr.Dataset: """Read the embedding dataset written by the EmbeddingWriter callback. Supports both legacy datasets (without x/y coordinates) and new datasets. @@ -34,7 +35,7 @@ def read_embedding_dataset(path: Path) -> Dataset: Returns ------- - Dataset + xr.Dataset Xarray dataset with features and projections. """ dataset = open_zarr(path) @@ -60,8 +61,8 @@ def _move_and_stack_embeddings( def write_embedding_dataset( - output_path: Path, - features: np.ndarray, + output_path: str | Path, + features: NDArray, index_df: pd.DataFrame, projections: np.ndarray | None = None, umap_kwargs: dict[str, Any] | None = None, @@ -74,9 +75,9 @@ def write_embedding_dataset( Parameters ---------- - output_path : Path + output_path : str | Path Path to the zarr store. - features : np.ndarray + features : NDArray Array of shape (n_samples, n_features) containing the embeddings. index_df : pd.DataFrame DataFrame containing the index information for each embedding. @@ -191,11 +192,12 @@ class EmbeddingWriter(BasePredictionWriter): Path to the zarr store. write_interval : Literal["batch", "epoch", "batch_and_epoch"], optional When to write the embeddings, by default 'epoch'. - umap_kwargs : dict, optional + umap_kwargs : dict[str, Any], optional Keyword arguments passed to UMAP, by default None (i.e. UMAP is not computed). 
- phate_kwargs : dict, optional + phate_kwargs : dict[str, Any], optional Keyword arguments passed to PHATE, by default PHATE is computed with default parameters. - pca_kwargs : dict, optional + Default configuration passed is: {"knn": 5, "decay": 40, "n_jobs": -1, "random_state": 42}. + pca_kwargs : dict[str, Any], optional Keyword arguments passed to PCA, by default PCA is computed with default parameters. """ diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 1e7ae93aa..a6d3f3db9 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -2,10 +2,10 @@ from collections.abc import Sequence from typing import Literal, TypedDict -import numpy as np import torch import torch.nn.functional as F from lightning.pytorch import LightningModule +from numpy.typing import NDArray from pytorch_metric_learning.losses import NTXentLoss from torch import Tensor, nn from umap import UMAP @@ -29,7 +29,27 @@ class ContrastivePrediction(TypedDict): class ContrastiveModule(LightningModule): - """Contrastive Learning Model for self-supervised learning.""" + """Contrastive Learning Model for self-supervised learning. + + Parameters + ---------- + encoder : nn.Module | ContrastiveEncoder + Encoder model. + loss_function : nn.Module | nn.CosineEmbeddingLoss | nn.TripletMarginLoss | NTXentLoss + Loss function. By default, nn.TripletMarginLoss with margin 0.5. + lr : float + Learning rate. By default, 1e-3. + schedule : Literal["WarmupCosine", "Constant"] + Schedule for learning rate. By default, "Constant". + log_batches_per_epoch : int + Number of batches to log. By default, 8. + log_samples_per_batch : int + Number of samples to log. By default, 1. + log_embeddings : bool + Whether to log embeddings. By default, False. + example_input_array_shape : Sequence[int] + Shape of example input array. + """ def __init__( self, @@ -86,7 +106,9 @@ def log_feature_statistics(self, embeddings: Tensor, prefix: str): _logger.debug(f"{prefix}_mean: {mean}") _logger.debug(f"{prefix}_std: {std}") - def print_embedding_norms(self, anchor, positive, negative, phase): + def print_embedding_norms( + self, anchor: Tensor, positive: Tensor, negative: Tensor, phase: str + ): """Log L2 norms of embeddings for triplet components. Parameters @@ -108,7 +130,12 @@ def print_embedding_norms(self, anchor, positive, negative, phase): _logger.debug(f"{phase}/negative_norm: {negative_norm}") def _log_metrics( - self, loss, anchor, positive, stage: Literal["train", "val"], negative=None + self, + loss: Tensor, + anchor: Tensor, + positive: Tensor, + stage: Literal["train", "val"], + negative: Tensor | None = None, ): self.log( f"loss/{stage}", @@ -144,7 +171,7 @@ def _log_metrics( sync_dist=True, ) - def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + def _log_samples(self, key: str, imgs: Sequence[Sequence[NDArray]]): grid = render_images(imgs, cmaps=["gray"] * 3) self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" @@ -295,19 +322,19 @@ def on_validation_epoch_end(self) -> None: self.validation_step_outputs = [] - def configure_optimizers(self): + def configure_optimizers(self) -> torch.optim.AdamW: """Configure optimizer for contrastive learning. Returns ------- - torch.optim.Optimizer + torch.optim.AdamW AdamW optimizer with configured learning rate. 
""" optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr) return optimizer def predict_step( - self, batch: TripletSample, batch_idx, dataloader_idx=0 + self, batch: TripletSample, batch_idx: int, dataloader_idx: int = 0 ) -> ContrastivePrediction: """Prediction step for extracting embeddings.""" features, projections = self.model(batch["anchor"]) diff --git a/viscy/representation/evaluation/visualization.py b/viscy/representation/evaluation/visualization.py index f6e433ddc..df08d0e56 100644 --- a/viscy/representation/evaluation/visualization.py +++ b/viscy/representation/evaluation/visualization.py @@ -580,7 +580,7 @@ def update_figure( ) def update_track_timeline(clickData: dict[str, Any] | None) -> html.Div: """Update the track timeline based on the clicked point - + Parameters ---------- clickData: dict[str, Any] | None @@ -1632,7 +1632,7 @@ def _numpy_to_base64(img_array: NDArray) -> str: "utf-8" ) - def save_cache(self, cache_path: str |Path | None = None) -> None: + def save_cache(self, cache_path: str | Path | None = None) -> None: """Save the image cache to disk using pickle. Parameters diff --git a/viscy/representation/multi_modal.py b/viscy/representation/multi_modal.py index c0f1cfe94..51e429a45 100644 --- a/viscy/representation/multi_modal.py +++ b/viscy/representation/multi_modal.py @@ -96,7 +96,29 @@ def forward_projections( class JointContrastiveModule(ContrastiveModule): - """CLIP-style model pair for self-supervised cross-modality representation learning.""" + """CLIP-style model pair for self-supervised cross-modality representation learning. + + Parameters + ---------- + encoder : nn.Module | JointEncoders + Encoder model. + loss_function : nn.Module | nn.CosineEmbeddingLoss | nn.TripletMarginLoss | NTXentLoss + Loss function. By default, nn.TripletMarginLoss with margin 0.5. + lr : float + Learning rate. By default, 1e-3. + schedule : Literal["WarmupCosine", "Constant"] + Schedule for learning rate. By default, "Constant". + log_batches_per_epoch : int + Number of batches to log. By default, 8. + log_samples_per_batch : int + Number of samples to log. By default, 1. + log_embeddings : bool + Whether to log embeddings. By default, False. + example_input_array_shape : Sequence[int] + Shape of example input array. + prediction_arm : Literal["source", "target"] + Arm to use for prediction. By default, "source". + """ def __init__( self, diff --git a/viscy/transforms/_gaussian_blur.py b/viscy/transforms/_gaussian_blur.py index 1f32d9a99..1522d7de5 100644 --- a/viscy/transforms/_gaussian_blur.py +++ b/viscy/transforms/_gaussian_blur.py @@ -12,6 +12,25 @@ class RandomGaussianBlur(IntensityAugmentationBase3D): + """ + Random Gaussian Blur. + + Parameters + ---------- + kernel_size : tuple[int, int, int] | int + Kernel size. + sigma : tuple[float, float, float] | Tensor + Sigma. + border_type : str, optional + Border type. By default, "reflect". + same_on_batch : bool, optional + Whether to apply the same transformation to all batches. By default, False. + p : float, optional + Probability of applying the transformation. By default, 0.5. + keepdim : bool, optional + Whether to keep the dimensions of the input tensor. By default, False. + """ + def __init__( self, kernel_size: tuple[int, int, int] | int, @@ -44,6 +63,27 @@ def apply_transform( class BatchedRandGaussianBlurd(MapTransform, RandomizableTransform): + """ + Batched Random Gaussian Blur. + + Parameters + ---------- + keys : str | Iterable[str] + Keys to apply the transformation to. 
+ kernel_size : tuple[int, int] | int + Kernel size. + sigma : tuple[float, float] + Sigma. + border_type : str, optional + Border type. By default, "reflect". + same_on_batch : bool, optional + Whether to apply the same transformation to all batches. By default, False. + prob : float, optional + Probability of applying the transformation. By default, 0.1. + allow_missing_keys : bool, optional + Whether to allow missing keys. By default, False. + """ + def __init__( self, keys: str | Iterable[str], diff --git a/viscy/transforms/_transforms.py b/viscy/transforms/_transforms.py index d6b68e6b0..50ad7a812 100644 --- a/viscy/transforms/_transforms.py +++ b/viscy/transforms/_transforms.py @@ -76,7 +76,18 @@ def _normalize(): class RandInvertIntensityd(MapTransform, RandomizableTransform): - """Randomly invert the intensity of the image.""" + """ + Randomly invert the intensity of the image. + + Parameters + ---------- + keys : str | Iterable[str] + Keys to invert the intensity of. + prob : float, optional + Probability of inverting the intensity. By default, 0.1. + allow_missing_keys : bool, optional + Whether to allow missing keys. By default, False. + """ def __init__( self, @@ -101,6 +112,15 @@ class TiledSpatialCropSamplesd(MapTransform, MultiSampleTrait): """Crop multiple tiled ROIs from an image. Used for deterministic cropping in validation. + + Parameters + ---------- + keys : str | Iterable[str] + Keys to crop. + roi_size : tuple[int, int, int] + ROI size. + num_samples : int + Number of samples. """ def __init__( @@ -148,7 +168,13 @@ def __call__(self, sample: Sample) -> Sample: class StackChannelsd(MapTransform): - """Stack source and target channels.""" + """Stack source and target channels. + + Parameters + ---------- + channel_map : ChannelMap + Channel map. + """ def __init__(self, channel_map: ChannelMap) -> None: channel_names = [] @@ -165,7 +191,21 @@ def __call__(self, sample: Sample) -> Sample: class BatchedZoom(Transform): - """Batched zoom transform using ``torch.nn.functional.interpolate``.""" + """Batched zoom transform using ``torch.nn.functional.interpolate``. + + Parameters + ---------- + scale_factor : float | tuple[float, float, float] + Scale factor. + mode : Literal["nearest", "nearest-exact", "linear", "bilinear", "bicubic", "trilinear", "area"] + Mode. + align_corners : bool | None + Align corners. + recompute_scale_factor : bool | None + Recompute scale factor. + antialias : bool + Whether to use antialiasing. + """ def __init__( self, @@ -201,6 +241,8 @@ def __call__(self, sample: Tensor) -> Tensor: class BatchedScaleIntensityRangePercentiles(ScaleIntensityRangePercentiles): + """Batched scale intensity range percentiles.""" + def _normalize(self, img: Tensor) -> Tensor: q_low = self.lower / 100.0 q_high = self.upper / 100.0 @@ -245,6 +287,32 @@ def __call__(self, img: Tensor) -> Tensor: class BatchedScaleIntensityRangePercentilesd(MapTransform): + """Batched scale intensity range percentiles. + + Parameters + ---------- + keys : str | Iterable[str] + Keys to scale. + lower : float + Lower percentile. + upper : float + Upper percentile. + b_min : float | None + Minimum value. + b_max : float | None + Maximum value. + clip : bool + Whether to clip the values. + relative : bool + Whether to use relative scaling. + channel_wise : bool + Whether to use channel-wise scaling. + dtype : DTypeLike + Data type. + allow_missing_keys : bool, optional + Whether to allow missing keys. By default, False. 
+ """ + def __init__( self, keys: str | Iterable[str], @@ -271,6 +339,28 @@ def __call__(self, data: dict[str, Tensor]) -> dict[str, Tensor]: class BatchedRandAffined(MapTransform): + """Batched random affine. + + Parameters + ---------- + keys : str | Iterable[str] + Keys to affine. + prob : float, optional + Probability of affine. By default, 0.1. + rotate_range : Sequence[tuple[float, float] | float] | float | None + Rotate range. + shear_range : Sequence[tuple[float, float] | float] | float | None + Shear range. + translate_range : Sequence[tuple[float, float] | float] | float | None + Translate range. + scale_range : Sequence[tuple[float, float] | float] | float | None + Scale range. + mode : str, optional + Mode. By default, "bilinear". + allow_missing_keys : bool, optional + Whether to allow missing keys. By default, False. + """ + def __init__( self, keys: str | Iterable[str], @@ -334,6 +424,8 @@ def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]: class RandGaussianNoiseTensor(RandGaussianNoise): + """Rand Gaussian Noise Tensor.""" + def randomize(self, img: Tensor, mean: float | None = None) -> None: self._do_transform = self.R.rand() < self.prob if not self._do_transform: @@ -349,6 +441,26 @@ def randomize(self, img: Tensor, mean: float | None = None) -> None: class RandGaussianNoiseTensord(RandGaussianNoised): + """Rand Gaussian Noise Tensor. + + Parameters + ---------- + keys : str | Iterable[str] + Keys to noise. + prob : float, optional + Probability of noise. By default, 0.1. + mean : float, optional + Mean. By default, 0.0. + std : float, optional + Standard deviation. By default, 0.1. + dtype : DTypeLike, optional + Data type. By default, np.float32. + allow_missing_keys : bool, optional + Whether to allow missing keys. By default, False. + sample_std : bool, optional + Whether to sample the standard deviation. By default, True. + """ + def __init__( self, keys: str | Iterable[str], diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 22e6f3e40..40eb0dd4c 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -13,7 +13,7 @@ from monai.optimizers import WarmupCosineSchedule from monai.transforms import DivisiblePad, Rotate90 from torch import Tensor, nn -from torch.optim.lr_scheduler import ConstantLR +from torch.optim.lr_scheduler import ConstantLR, LRScheduler from torchmetrics.functional import ( accuracy, cosine_similarity, @@ -73,19 +73,19 @@ def __init__( self.ms_dssim_alpha = ms_dssim_alpha @torch.amp.custom_fwd(device_type="cuda", cast_inputs=torch.float32) - def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + def forward(self, preds: Tensor, target: Tensor) -> Tensor: """Compute mixed reconstruction loss. Parameters ---------- - preds : torch.Tensor + preds : Tensor Predicted tensor - target : torch.Tensor + target : Tensor Target tensor Returns ------- - torch.Tensor + Tensor Combined loss value """ loss = 0 @@ -109,23 +109,21 @@ class MaskedMSELoss(nn.Module): Computes MSE loss only for masked regions. """ - def forward( - self, preds: torch.Tensor, original: torch.Tensor, mask: torch.Tensor - ) -> torch.Tensor: + def forward(self, preds: Tensor, original: Tensor, mask: Tensor) -> Tensor: """Compute masked MSE loss. Parameters ---------- - preds : torch.Tensor + preds : Tensor Predicted tensor. - original : torch.Tensor + original : Tensor Original tensor. - mask : torch.Tensor + mask : Tensor Binary mask tensor. 
Returns ------- - torch.Tensor + Tensor Masked MSE loss value. """ loss = F.mse_loss(preds, original, reduction="none") @@ -240,31 +238,29 @@ def forward(self, x: Tensor) -> Tensor: Parameters ---------- - x : torch.Tensor + x : Tensor Input tensor. Returns ------- - torch.Tensor + Tensor Model output. """ return self.model(x) - def training_step( - self, batch: Sample | Sequence[Sample], batch_idx: int - ) -> torch.Tensor: + def training_step(self, batch: Sample | Sequence[Sample], batch_idx: int) -> Tensor: """Execute single training step. Parameters ---------- - batch : Sample or Sequence[Sample] + batch : Sample | Sequence[Sample] Training batch data. batch_idx : int Batch index. Returns ------- - torch.Tensor + Tensor Training loss. """ losses = [] @@ -306,8 +302,8 @@ def validation_step( Validation batch data. batch_idx : int Batch index. - dataloader_idx : int, default=0 - Dataloader index for multi-dataloader validation. + dataloader_idx : int + Dataloader index for multi-dataloader validation. By default, 0. """ source: Tensor = batch["source"] target: Tensor = batch["target"] @@ -390,7 +386,7 @@ def _log_regression_metrics(self, pred: Tensor, target: Tensor) -> None: on_epoch=True, ) - def _cellpose_predict(self, pred: Tensor, name: str) -> torch.ShortTensor: + def _cellpose_predict(self, pred: Tensor, name: str) -> Tensor: pred_labels_np = self.cellpose_model.eval( pred.cpu().numpy(), channels=[0, 0], diameter=self.test_cellpose_diameter )[0].astype(np.int16) @@ -398,7 +394,7 @@ def _cellpose_predict(self, pred: Tensor, name: str) -> torch.ShortTensor: return torch.from_numpy(pred_labels_np).to(self.device) def _log_segmentation_metrics( - self, pred_labels: torch.ShortTensor, target_labels: torch.ShortTensor + self, pred_labels: Tensor, target_labels: Tensor ) -> None: compute = pred_labels is not None if compute: @@ -440,7 +436,7 @@ def _log_segmentation_metrics( def predict_step( self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 - ) -> dict[str, Any]: + ) -> Tensor: """Execute single prediction step. Parameters @@ -454,7 +450,7 @@ def predict_step( Returns ------- - torch.Tensor + Tensor Model prediction. """ source = batch["source"] @@ -475,12 +471,12 @@ def perform_test_time_augmentations(self, source: Tensor) -> Tensor: Parameters ---------- - source : torch.Tensor + source : Tensor Input tensor. Returns ------- - torch.Tensor + Tensor Aggregated prediction. """ # Save the yx coords to crop post rotations @@ -531,7 +527,13 @@ def on_validation_epoch_end(self) -> None: self.validation_losses.clear() def on_test_start(self) -> None: - """Load CellPose model for segmentation.""" + """Load CellPose model for segmentation. + + Raises + ------ + ImportError + If CellPose is not installed. + """ if self.test_cellpose_model_path is not None: try: from cellpose.models import CellposeModel @@ -555,13 +557,15 @@ def on_predict_start(self) -> None: down_factor = 2**self.model.num_blocks self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[Any]]: + def configure_optimizers( + self, + ) -> tuple[list[torch.optim.Optimizer], list[LRScheduler]]: """Configure optimizer and learning rate scheduler. Returns ------- - tuple - Tuple containing optimizer and scheduler lists. + tuple[list[torch.optim.Optimizer], list[LRScheduler]] + Tuple containing a list of optimizers and schedulers. 
""" if self.freeze_encoder: self.model: FullyConvolutionalMAE @@ -669,12 +673,12 @@ def forward(self, x: Tensor) -> Tensor: Parameters ---------- - x : torch.Tensor + x : Tensor Input tensor. Returns ------- - torch.Tensor + Tensor Model output. """ return self.model(x) @@ -702,12 +706,12 @@ def _reduce_predictions(self, preds: list[Tensor]) -> Tensor: Parameters ---------- - preds : list[torch.Tensor] + preds : list[Tensor] List of prediction tensors. Returns ------- - torch.Tensor + Tensor Reduced prediction tensor. """ prediction = torch.stack(preds, dim=0) @@ -733,7 +737,7 @@ def predict_step( Returns ------- - torch.Tensor + Tensor Aggregated prediction from augmented inputs. """ source = batch["source"] @@ -811,14 +815,14 @@ def forward( Parameters ---------- - x : torch.Tensor + x : Tensor Input tensor. mask_ratio : float, default=0.0 Masking ratio for FCMAE mode. Returns ------- - torch.Tensor or tuple + Tensor or tuple Model output, optionally with mask if mask_ratio > 0. """ return self.model(x, mask_ratio) @@ -837,7 +841,7 @@ def forward_fit_fcmae( Returns ------- - tuple[torch.Tensor, torch.Tensor or None, torch.Tensor] + tuple[Tensor, Tensor or None, Tensor] Prediction, target (if requested), and loss. """ x = batch["source"] @@ -859,7 +863,7 @@ def forward_fit_supervised(self, batch: Sample) -> tuple[Tensor, Tensor, Tensor] Returns ------- - tuple[torch.Tensor, torch.Tensor, torch.Tensor] + tuple[Tensor, Tensor, Tensor] Prediction, target, and loss. """ x = batch["source"] @@ -885,7 +889,7 @@ def forward_fit_task( Returns ------- - tuple[torch.Tensor, torch.Tensor or None, torch.Tensor] + tuple[Tensor, Tensor | None, Tensor] Prediction, target, and loss. """ if self.model.pretraining: @@ -902,7 +906,7 @@ def train_transform_and_collate(self, batch: list[dict[str, Tensor]]) -> Sample: Parameters ---------- - batch : list[dict[str, torch.Tensor]] + batch : list[dict[str, Tensor]] List of batch dictionaries from multiple data modules. Returns @@ -933,7 +937,7 @@ def val_transform_and_collate( Returns ------- - torch.Tensor + Tensor Collated and transformed batch. """ batch = self.datamodules[dataloader_idx].val_gpu_transforms(batch) @@ -951,7 +955,7 @@ def training_step(self, batch: list[list[Sample]], batch_idx: int) -> Tensor: Returns ------- - torch.Tensor + Tensor Training loss. """ batch = self.train_transform_and_collate(batch) diff --git a/viscy/translation/evaluation.py b/viscy/translation/evaluation.py index 78376548f..42ef08c48 100644 --- a/viscy/translation/evaluation.py +++ b/viscy/translation/evaluation.py @@ -12,7 +12,13 @@ class SegmentationMetrics2D(LightningModule): - """Test runner for 2D segmentation.""" + """Test runner for 2D segmentation. + + Parameters + ---------- + aggregate_epoch : bool, optional + Whether to aggregate the metrics over the epoch. Defaults to False. 
+ """ def __init__(self, aggregate_epoch: bool = False) -> None: super().__init__() diff --git a/viscy/translation/evaluation_metrics.py b/viscy/translation/evaluation_metrics.py index 598b52a69..7403e5f08 100644 --- a/viscy/translation/evaluation_metrics.py +++ b/viscy/translation/evaluation_metrics.py @@ -1,20 +1,18 @@ """Metrics for model evaluation.""" from collections.abc import Sequence -from typing import Union -from warnings import warn import numpy as np import torch import torch.nn.functional as F from monai.metrics.regression import compute_ssim_and_cs +from numpy.typing import NDArray from scipy.optimize import linear_sum_assignment from skimage.measure import label, regionprops from torchmetrics.detection.mean_ap import MeanAveragePrecision -from torchvision.ops import masks_to_boxes -def VOI_metric(target, prediction): +def VOI_metric(target: NDArray, prediction: NDArray) -> list[float]: """ Variation of information metric. @@ -22,14 +20,14 @@ def VOI_metric(target, prediction): Parameters ---------- - target : np.array + target : NDArray Ground truth mask. - prediction : np.array + prediction : NDArray Model inferred FL image cellpose mask. Returns ------- - list of float + list[float] VI for image masks. """ # cellpose segmentation of predicted image: outputs labl mask @@ -67,21 +65,28 @@ def VOI_metric(target, prediction): return [VI] -def POD_metric(target_bin, pred_bin): +def POD_metric( + target_bin: NDArray, pred_bin: NDArray +) -> tuple[float, float, float, int, int]: """ Probability of detection metric for object matching. Parameters ---------- - target_bin : array_like + target_bin : NDArray Binary ground truth mask. - pred_bin : array_like + pred_bin : NDArray Binary predicted mask. Returns ------- - tuple + tuple[float, float, float, int, int] POD and various detection statistics. + - POD: Probability of detection + - FAR: False alarm rate + - PCD: Probability of correct detection + - n_targObj: Number of target objects + - n_predObj: Number of predicted objects """ # pred_bin = cpmask_array(prediction) @@ -136,7 +141,7 @@ def POD_metric(target_bin, pred_bin): # probability of correct detection PCD = len(matching_targ) / len(props_targ) - return [POD, FAR, PCD, len(props_targ), len(props_pred)] + return (POD, FAR, PCD, len(props_targ), len(props_pred)) def compute_3d_dice_score( @@ -146,7 +151,26 @@ def compute_3d_dice_score( threshold: float = 0.5, aggregate: bool = True, ) -> torch.Tensor: - """Compute 3D Dice similarity coefficient.""" + """Compute 3D Dice similarity coefficient. + + Parameters + ---------- + y_true : torch.Tensor + True labels. + y_pred : torch.Tensor + Predicted labels. + eps : float, optional + Epsilon to avoid division by zero. Defaults to 1e-8. + threshold : float, optional + Threshold for binarization. Defaults to 0.5. + aggregate : bool, optional + Whether to aggregate the dice score. Defaults to True. + + Returns + ------- + torch.Tensor + Dice score. + """ y_pred_thresholded = (y_pred > threshold).float() intersection = torch.sum(y_true * y_pred_thresholded, dim=(-3, -2, -1)) total = torch.sum(y_true + y_pred_thresholded, dim=(-3, -2, -1)) @@ -161,7 +185,22 @@ def compute_jaccard_index( y_pred: torch.Tensor, threshold: float = 0.5, ) -> torch.Tensor: - """Compute Jaccard index (IoU).""" + """Compute Jaccard index (IoU). + + Parameters + ---------- + y_true : torch.Tensor + True labels. + y_pred : torch.Tensor + Predicted labels. + threshold : float, optional + Threshold for binarization. Defaults to 0.5. 
+
+    Returns
+    -------
+    torch.Tensor
+        Jaccard index.
+    """
     y_pred_thresholded = y_pred > threshold
     intersection = torch.sum(y_true & y_pred_thresholded, dim=(-3, -2, -1))
     union = torch.sum(y_true | y_pred_thresholded, dim=(-3, -2, -1))
@@ -171,7 +210,22 @@ def compute_pearson_correlation_coefficient(
     y_true: torch.Tensor, y_pred: torch.Tensor, dim: Sequence[int] | None = None
 ) -> torch.Tensor:
-    """Compute Pearson correlation coefficient."""
+    """Compute Pearson correlation coefficient.
+
+    Parameters
+    ----------
+    y_true : torch.Tensor
+        True labels.
+    y_pred : torch.Tensor
+        Predicted labels.
+    dim : Sequence[int] | None, optional
+        Dimensions over which to compute the Pearson correlation coefficient.
+        Defaults to None (spatial dimensions).
+
+    Returns
+    -------
+    torch.Tensor
+        Pearson correlation coefficient.
+    """
     if dim is None:
         # default to spatial dimensions
         dim = (-3, -2, -1)
@@ -192,7 +246,20 @@ def compute_pearson_correlation_coefficient(
 
 
 class MeanAveragePrecisionNuclei(MeanAveragePrecision):
-    """Mean Average Precision for nuclei detection."""
+    """Mean Average Precision for nuclei detection.
+
+    Parameters
+    ----------
+    min_area : int, optional
+        Minimum area of nuclei to be considered. Defaults to 20.
+    iou_threshold : float, optional
+        IoU threshold for matching. Defaults to 0.5.
+
+    Notes
+    -----
+    ``compute()`` returns the mean average precision score as a
+    ``torch.Tensor``.
+    """
 
     def __init__(self, min_area: int = 20, iou_threshold: float = 0.5) -> None:
         super().__init__(iou_thresholds=[iou_threshold])
@@ -283,7 +350,7 @@ def ssim_loss_25d(
 
     Returns
     -------
-    torch.Tensor or tuple[torch.Tensor, torch.Tensor]
+    torch.Tensor | tuple[torch.Tensor, torch.Tensor]
         SSIM for the batch, optionally with contrast sensitivity.
     """
     if preds.ndim != 5:
diff --git a/viscy/translation/predict_writer.py b/viscy/translation/predict_writer.py
index 96d34971d..29c0e41be 100644
--- a/viscy/translation/predict_writer.py
+++ b/viscy/translation/predict_writer.py
@@ -187,7 +187,7 @@ def write_on_batch_end(
         PyTorch Lightning module being used for predictions.
         prediction : torch.Tensor
             Batch of predictions from the model.
-        batch_indices : Optional[Sequence[int]]
+        batch_indices : Sequence[int] | None
             Indices of the batch samples.
         batch : Sample
             Input batch data.
diff --git a/viscy/unet/networks/Unet25D.py b/viscy/unet/networks/Unet25D.py
index 802bec409..41cda8628 100644
--- a/viscy/unet/networks/Unet25D.py
+++ b/viscy/unet/networks/Unet25D.py
@@ -230,7 +230,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         between them (decoder)
         => terminal block collapses to output dimensions
 
-        :param torch.tensor x: input image
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input image.
+
+        Returns
+        -------
+        torch.Tensor
+            Output image.
         """
         # encoder
         skip_tensors = []
diff --git a/viscy/unet/networks/Unet2D.py b/viscy/unet/networks/Unet2D.py
index e454942d7..225b43523 100644
--- a/viscy/unet/networks/Unet2D.py
+++ b/viscy/unet/networks/Unet2D.py
@@ -184,7 +184,7 @@ def __init__(
             kernel_size=self.kernel_size,
         )
 
-    def forward(self, x, validate_input=False):
+    def forward(self, x: torch.Tensor, validate_input: bool = False) -> torch.Tensor:
         """Forward pass through the 2D U-Net.
 
         Call order:
@@ -234,7 +234,7 @@ def forward(self, x, validate_input=False):
 
         return x.unsqueeze(2)
 
-    def register_modules(self, module_list, name):
-        """Helper function that registers modules stored in a list to the model object.
-
-        So that they can be seen by PyTorch optimizer.
+    def register_modules(self, module_list: list[nn.Module], name: str) -> None:
+        """Register modules stored in a list on the model object.
+
+        This makes them visible to the PyTorch optimizer.
diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index c529b0a13..51abeed86 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -42,10 +42,19 @@ def generate_mask( ) -> BoolTensor: """Generate random boolean mask for masked autoencoder training. - :param Size target: target shape - :param int stride: total stride - :param float mask_ratio: ratio of the pixels to mask - :return BoolTensor: boolean mask (B1HW) + Parameters + ---------- + target : Size + Target tensor shape. + stride : int + Total downsampling stride. + mask_ratio : float + Ratio of pixels to mask for training. + + Returns + ------- + BoolTensor + Boolean mask tensor of shape (B1HW). """ m_height = target[-2] // stride m_width = target[-1] // stride @@ -58,9 +67,17 @@ def generate_mask( def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: """Upsample boolean mask to match target spatial dimensions. - :param BoolTensor mask: low-resolution boolean mask (B1HW) - :param Size target: target size (BCHW) - :return BoolTensor: upsampled boolean mask (B1HW) + Parameters + ---------- + mask : BoolTensor + Low-resolution boolean mask of shape (B1HW). + target : Size + Target tensor size (BCHW). + + Returns + ------- + BoolTensor + Upsampled boolean mask of shape (B1HW). """ if target[-2:] != mask.shape[-2:]: if not all(i % j == 0 for i, j in zip(target, mask.shape)): @@ -77,9 +94,17 @@ def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: def masked_patchify(features: Tensor, unmasked: BoolTensor | None = None) -> Tensor: """Convert spatial features to channel-last patches, optionally masked. - :param Tensor features: input image features (BCHW) - :param BoolTensor unmasked: boolean foreground mask (B1HW) - :return Tensor: masked channel-last features (BLC, L = H * W * mask_ratio) + Parameters + ---------- + features : Tensor + Input image features of shape (BCHW). + unmasked : BoolTensor | None, optional + Boolean foreground mask of shape (B1HW), by default None. + + Returns + ------- + Tensor + Masked channel-last features of shape (BLC, L = H * W * mask_ratio). """ if unmasked is None: return features.flatten(2).permute(0, 2, 1) @@ -96,10 +121,19 @@ def masked_unpatchify( ) -> Tensor: """Convert channel-last patches back to spatial features. - :param Tensor features: dense channel-last features (BLC) - :param Size out_shape: output shape (BCHW) - :param BoolTensor | None unmasked: boolean foreground mask, defaults to None - :return Tensor: masked features (BCHW) + Parameters + ---------- + features : Tensor + Dense channel-last features of shape (BLC). + out_shape : Size + Output tensor shape (BCHW). + unmasked : BoolTensor | None, optional + Boolean foreground mask, by default None. + + Returns + ------- + Tensor + Masked spatial features of shape (BCHW). """ if unmasked is None: return features.permute(0, 2, 1).reshape(out_shape) @@ -115,12 +149,20 @@ def masked_unpatchify( class MaskedConvNeXtV2Block(nn.Module): """Masked ConvNeXt V2 Block. - :param int in_channels: input channels - :param int | None out_channels: output channels, defaults to None - :param int kernel_size: depth-wise convolution kernel size, defaults to 7 - :param int stride: downsample stride, defaults to 1 - :param int mlp_ratio: MLP expansion ratio, defaults to 4 - :param float drop_path: drop path rate, defaults to 0.0 + Parameters + ---------- + in_channels : int + Input channels. + out_channels : int | None, optional + Output channels, by default None. 
+ kernel_size : int, optional + Depth-wise convolution kernel size, by default 7. + stride : int, optional + Downsample stride, by default 1. + mlp_ratio : int, optional + MLP expansion ratio, by default 4. + drop_path : float, optional + Drop path rate, by default 0.0. """ def __init__( @@ -157,9 +199,17 @@ def __init__( def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: """Forward pass through masked ConvNeXt V2 block. - :param Tensor x: input tensor (BCHW) - :param BoolTensor | None unmasked: boolean foreground mask, defaults to None - :return Tensor: output tensor (BCHW) + Parameters + ---------- + x : Tensor + Input tensor of shape (BCHW). + unmasked : BoolTensor | None, optional + Boolean foreground mask, by default None. + + Returns + ------- + Tensor + Output tensor of shape (BCHW). """ shortcut = self.shortcut(x) if unmasked is not None: @@ -177,15 +227,22 @@ def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: class MaskedConvNeXtV2Stage(nn.Module): - """Masked ConvNeXt V2 Stage. - - :param int in_channels: input channels - :param int out_channels: output channels - :param int kernel_size: depth-wise convolution kernel size, defaults to 7 - :param int stride: downsampling factor of this stage, defaults to 2 - :param int num_blocks: number of residual blocks, defaults to 2 - :param Sequence[float] | None drop_path_rates: drop path rates of each block, - defaults to None + """Masked ConvNeXt V2 Stage for hierarchical feature extraction. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : int, optional + Depth-wise convolution kernel size, by default 7. + stride : int, optional + Downsampling factor of this stage, by default 2. + num_blocks : int, optional + Number of residual blocks, by default 2. + drop_path_rates : Sequence[float] | None, optional + Drop path rates of each block, by default None. """ def __init__( @@ -236,9 +293,17 @@ def __init__( def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: """Forward pass through masked ConvNeXt V2 stage. - :param Tensor x: input tensor (BCHW) - :param BoolTensor | None unmasked: boolean foreground mask, defaults to None - :return Tensor: output tensor (BCHW) + Parameters + ---------- + x : Tensor + Input tensor of shape (BCHW). + unmasked : BoolTensor | None, optional + Boolean foreground mask, by default None. + + Returns + ------- + Tensor + Output tensor of shape (BCHW). """ x = self.downsample(x) if unmasked is not None: @@ -249,14 +314,20 @@ def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: class MaskedAdaptiveProjection(nn.Module): - """ - Masked patchifying layer for projecting 2D or 3D input into 2D feature maps. - - :param int in_channels: input channels - :param int out_channels: output channels - :param Sequence[int, int] | int kernel_size_2d: kernel width and height - :param int kernel_depth: kernel depth for 3D input - :param int in_stack_depth: input stack depth for 3D input + """Masked patchifying layer for projecting 2D or 3D input into 2D feature maps. + + Parameters + ---------- + in_channels : int + Input channels. + out_channels : int + Output channels. + kernel_size_2d : tuple[int, int] | int, optional + Kernel width and height, by default 4. + kernel_depth : int, optional + Kernel depth for 3D input, by default 5. + in_stack_depth : int, optional + Input stack depth for 3D input, by default 5. 
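+
+    Examples
+    --------
+    Shape sketch (the sizes below are illustrative assumptions, not part of
+    the original documentation): with ``kernel_size_2d=4`` and a depth kernel
+    covering the full stack, a ``(B, C, 5, H, W)`` input is projected to a
+    ``(B, out_channels, H // 4, W // 4)`` 2D feature map.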
""" def __init__( @@ -289,9 +360,17 @@ def __init__( def forward(self, x: Tensor, unmasked: BoolTensor = None) -> Tensor: """Forward pass through masked adaptive projection layer. - :param Tensor x: input tensor (BCDHW) - :param BoolTensor unmasked: boolean foreground mask (B1HW), defaults to None - :return Tensor: output tensor (BCHW) + Parameters + ---------- + x : Tensor + Input tensor of shape (BCDHW). + unmasked : BoolTensor, optional + Boolean foreground mask of shape (B1HW), by default None. + + Returns + ------- + Tensor + Output tensor of shape (BCHW). """ # no need to mask before convolutions since patches do not spill over if x.shape[2] > 1: @@ -317,12 +396,20 @@ class MaskedMultiscaleEncoder(nn.Module): Implements hierarchical feature extraction through multiple ConvNeXt V2 stages with optional random masking for self-supervised pretraining. - :param int in_channels: input channels - :param Sequence[int] stage_blocks: number of blocks per encoder stage - :param Sequence[int] dims: feature dimensions at each stage - :param float drop_path_rate: stochastic depth rate - :param Sequence[int] stem_kernel_size: kernel sizes for adaptive projection - :param int in_stack_depth: input stack depth for 3D input + Parameters + ---------- + in_channels : int + Input channels. + stage_blocks : Sequence[int], optional + Number of blocks per encoder stage, by default (3, 3, 9, 3). + dims : Sequence[int], optional + Feature dimensions at each stage, by default (96, 192, 384, 768). + drop_path_rate : float, optional + Stochastic depth rate, by default 0.0. + stem_kernel_size : Sequence[int], optional + Kernel sizes for adaptive projection, by default (5, 4, 4). + in_stack_depth : int, optional + Input stack depth for 3D input, by default 5. """ def __init__( @@ -364,11 +451,18 @@ def forward( ) -> tuple[list[Tensor], BoolTensor | None]: """Extract multi-scale features with optional masking. - :param Tensor x: input tensor (BCDHW) - :param float mask_ratio: ratio of the feature maps to mask, - defaults to 0.0 (no masking) - :return list[Tensor]: output tensors (list of BCHW) - :return BoolTensor | None: boolean foreground mask, None if no masking + Parameters + ---------- + x : Tensor + Input tensor of shape (BCDHW). + mask_ratio : float, optional + Ratio of the feature maps to mask, by default 0.0 (no masking). + + Returns + ------- + tuple[list[Tensor], BoolTensor | None] + Output tensors as list of BCHW tensors and boolean foreground mask + (None if no masking). """ if mask_ratio > 0.0: mask = generate_mask( @@ -393,11 +487,18 @@ class PixelToVoxelShuffleHead(nn.Module): Converts 2D feature maps to 3D output volumes through pixel shuffle upsampling and channel-to-depth reshaping. - :param int in_channels: input feature channels - :param int out_channels: output channels per voxel - :param int out_stack_depth: output stack depth (Z dimension) - :param int xy_scaling: spatial upsampling factor - :param bool pool: whether to apply pooling in upsampling + Parameters + ---------- + in_channels : int + Input feature channels. + out_channels : int + Output channels per voxel. + out_stack_depth : int, optional + Output stack depth (Z dimension), by default 5. + xy_scaling : int, optional + Spatial upsampling factor, by default 4. + pool : bool, optional + Whether to apply pooling in upsampling, by default False. """ def __init__( @@ -424,8 +525,15 @@ def __init__( def forward(self, x: Tensor) -> Tensor: """Reconstruct 3D volume from 2D features. 
- :param Tensor x: input 2D features (BCHW) - :return Tensor: reconstructed 3D volume (BCDHW) + Parameters + ---------- + x : Tensor + Input 2D features of shape (BCHW). + + Returns + ------- + Tensor + Reconstructed 3D volume of shape (BCDHW). """ x = self.upsample(x) b, _, h, w = x.shape @@ -440,20 +548,32 @@ class FullyConvolutionalMAE(nn.Module): with a UNet-style decoder for reconstruction tasks. Supports both pretraining with masking and fine-tuning for downstream tasks. - # TODO: MANUAL_REVIEW - Complex encoder-decoder architecture with masking - - :param int in_channels: input channels - :param int out_channels: output channels - :param Sequence[int] encoder_blocks: blocks per encoder stage - :param Sequence[int] dims: feature dimensions per stage - :param float encoder_drop_path_rate: encoder stochastic depth rate - :param Sequence[int] stem_kernel_size: adaptive projection kernel sizes - :param int in_stack_depth: input stack depth for 3D data - :param int decoder_conv_blocks: decoder convolution blocks per stage - :param bool pretraining: whether in pretraining mode (returns mask) - :param bool head_conv: whether to use convolutional reconstruction head - :param int head_conv_expansion_ratio: expansion ratio for conv head - :param bool head_conv_pool: whether to use pooling in conv head + Parameters + ---------- + in_channels : int + Input channels. + out_channels : int + Output channels. + encoder_blocks : Sequence[int], optional + Blocks per encoder stage, by default [3, 3, 9, 3]. + dims : Sequence[int], optional + Feature dimensions per stage, by default [96, 192, 384, 768]. + encoder_drop_path_rate : float, optional + Encoder stochastic depth rate, by default 0.0. + stem_kernel_size : Sequence[int], optional + Adaptive projection kernel sizes, by default (5, 4, 4). + in_stack_depth : int, optional + Input stack depth for 3D data, by default 5. + decoder_conv_blocks : int, optional + Decoder convolution blocks per stage, by default 1. + pretraining : bool, optional + Whether in pretraining mode (returns mask), by default True. + head_conv : bool, optional + Whether to use convolutional reconstruction head, by default False. + head_conv_expansion_ratio : int, optional + Expansion ratio for conv head, by default 4. + head_conv_pool : bool, optional + Whether to use pooling in conv head, by default True. """ def __init__( @@ -519,15 +639,25 @@ def __init__( self.num_blocks = len(dims) * int(math.log2(stem_kernel_size[-1])) self.pretraining = pretraining - def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: + def forward( + self, x: Tensor, mask_ratio: float = 0.0 + ) -> Tensor | tuple[Tensor, BoolTensor]: """Forward pass through FC-MAE architecture. Encodes input with optional masking, decodes through UNet decoder, and reconstructs output through pixel-to-voxel head. - :param Tensor x: input tensor (BCDHW) - :param float mask_ratio: masking ratio for pretraining (0.0 = no mask) - :return Tensor: reconstructed output (BCDHW) or tuple with mask + Parameters + ---------- + x : Tensor + Input tensor of shape (BCDHW). + mask_ratio : float, optional + Masking ratio for pretraining, by default 0.0 (no mask). + + Returns + ------- + Tensor | tuple[Tensor, BoolTensor] + Reconstructed output of shape (BCDHW) or tuple with mask. 
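+
+        Examples
+        --------
+        Hedged usage sketch (``model`` and ``x`` are hypothetical
+        placeholders for an instance and a BCDHW input tensor):
+
+        >>> recon = model(x)  # reconstruction only (mask_ratio=0.0)
+        >>> recon, mask = model(x, mask_ratio=0.5)  # masked pretraining mode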
""" x, mask = self.encoder(x, mask_ratio=mask_ratio) x.reverse() diff --git a/viscy/unet/networks/layers/ConvBlock2D.py b/viscy/unet/networks/layers/ConvBlock2D.py index c5f35c9e5..a713edfad 100644 --- a/viscy/unet/networks/layers/ConvBlock2D.py +++ b/viscy/unet/networks/layers/ConvBlock2D.py @@ -31,29 +31,33 @@ def __init__( ) -> None: """Initialize convolutional block for lateral layers in U-Net. - Format for layer initialization is as follows: - if layer type specified - => for number of layers - => add layer to list of that layer type - => register elements of list - This is done to allow for dynamic layer number specification in the conv blocks, - which allows us to change the parameter numbers of the network. - - :param int in_filters: number of images in in stack - :param int out_filters: number of images in out stack - :param float dropout: dropout probability (False => 0) - :param str norm: normalization type: 'batch', 'instance' - :param bool residual: as name - :param str activation: activation function: 'relu', 'leakyrelu', 'elu', 'selu' - :param bool transpose: as name - :param int/tuple kernel_size: convolutional kernel size - :param int num_repeats: number of times the layer_order layer sequence - is repeated in the block - :param str filter_steps: determines where in the block - the filters inflate channels (learn abstraction information): - 'linear','first','last' - :param str layer_order: order of conv, norm, and act layers in block: - 'can', 'cna', 'nca', etc + Format for layer initialization allows dynamic layer number specification + in the conv blocks, enabling parameter number flexibility across the network. + + Parameters + ---------- + in_filters : int + Number of input feature channels. + out_filters : int + Number of output feature channels. + dropout : float or bool, default=False + Dropout probability. If False, no dropout is applied. + norm : {"batch", "instance"}, default="batch" + Normalization type to apply. + residual : bool, default=True + Whether to include residual connections. + activation : {"relu", "leakyrelu", "elu", "selu", "linear"}, default="relu" + Activation function type. + transpose : bool, default=False + Whether to use transpose convolution layers. + kernel_size : int or tuple[int, int], default=3 + 2D convolutional kernel size. + num_repeats : int, default=3 + Number of times the layer_order sequence is repeated in the block. + filter_steps : {"linear", "first", "last"}, default="first" + Strategy for channel dimension changes across layers. + layer_order : str, default="can" + Order of conv (c), activation (a), normalization (n) layers. """ super().__init__() self.in_filters = in_filters @@ -274,16 +278,14 @@ def forward(self, x: torch.Tensor, validate_input: bool = False) -> torch.Tensor """Forward pass through the convolutional block. Order of layers within the block is defined by the 'layer_order' parameter, - which is a string of 'c's, 'a's and 'n's - in reference to convolution, activation, and normalization layers. - This sequence is repeated num_repeats times. + which is a string of 'c's, 'a's and 'n's in reference to convolution, + activation, and normalization layers. This sequence is repeated num_repeats times. - Recommended layer order: convolution -> activation -> normalization + Recommended layer order: convolution -> activation -> normalization - Regardless of layer order, - the final layer sequence in the block always ends in activation. 
- This allows for usage of passthrough layers - or a final output activation function determined separately. + Regardless of layer order, the final layer sequence in the block always ends + in activation. This allows for usage of passthrough layers or a final output + activation function determined separately. Residual blocks: if input channels are greater than output channels, @@ -291,9 +293,18 @@ def forward(self, x: torch.Tensor, validate_input: bool = False) -> torch.Tensor if input channels are less than output channels, we zero-pad input channels to output channel size. - :param torch.tensor x: input tensor - :param bool validate_input: Deactivates assertions - which are redundant if forward pass is being traced by tensorboard writer. + Parameters + ---------- + x : torch.Tensor + Input tensor for convolutional processing. + validate_input : bool, default=False + Deactivates assertions which are redundant if forward pass is being + traced by tensorboard writer. + + Returns + ------- + torch.Tensor + Output tensor after convolutional block processing. """ if validate_input: if isinstance(self.kernel_size, int): diff --git a/viscy/unet/networks/layers/ConvBlock3D.py b/viscy/unet/networks/layers/ConvBlock3D.py index 8f5339277..3840c43a0 100644 --- a/viscy/unet/networks/layers/ConvBlock3D.py +++ b/viscy/unet/networks/layers/ConvBlock3D.py @@ -62,48 +62,6 @@ def __init__( layer_order: str = "can", padding: str | int | tuple[int, ...] | None = None, ) -> None: - """ - Convolutional block for lateral layers in Unet. - - This block only accepts tensors of dimensions in - order [...,z,x,y] or [...,z,y,x] - - Format for layer initialization is as follows: - if layer type specified - => for number of layers - => add layer to list of that layer type - This is done to allow for dynamic layer number specification in the conv blocks, - which allows us to change the parameter numbers of the network. - - Only 'same' convolutional padding is recommended, - as the conv blocks are intended for deep Unets. - However padding can be specified as follows: - padding -> token{'same', 'valid', 'valid_stack'} or tuple(int) or int: - -> 'same': pads with same convolution - -> 'valid': pads for valid convolution on all dimensions - -> 'valid_stack': pads for valid convolution on xy dims (-1, -2), - same on z dim (-3). - -> tuple (int): pads above and below corresponding dimensions - -> int: pads above and below all dimensions - - :param int in_filters: number of images in in stack - :param int out_filters: number of images in out stack - :param float dropout: dropout probability (False => 0) - :param str norm: normalization type: 'batch', 'instance' - :param bool residual: as name - :param str activation: activation function: 'relu', 'leakyrelu', 'elu', 'selu' - :param bool transpose: as name - :param int/tuple kernel_size: convolutional kernel size - :param int num_repeats: as name - :param str filter_steps: determines where in the block - the filters inflate channels - (learn abstraction information): 'linear','first','last' - :param str layer_order: order of conv, norm, and act layers in block: - 'can', 'cna', etc. 
- NOTE: for now conv must always come first as required by norm feature counts - :paramn str/tuple(int)/tuple/None padding: convolutional padding, - see docstring for details - """ super().__init__() self.in_filters = in_filters self.out_filters = out_filters @@ -308,7 +266,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if input channels are less than output channels, we zero-pad input channels to output channel size - :param torch.tensor x: input tensor + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Output tensor. """ x_0 = x for i in range(self.num_repeats): @@ -356,7 +322,7 @@ def model(self) -> nn.Sequential: """ Allows calling of parameters inside ConvBlock object. - Layer order: convolution -> normalization -> activation + Layer order: convolution -> normalization -> activation We can make a list of layer modules and unpack them into nn.Sequential. Note: this is distinct from the forward call @@ -364,6 +330,11 @@ def model(self) -> nn.Sequential: since this is a residual block. The forward call performs the residual calculation, and all the parameters can be seen by the optimizer when given this model. + + Returns + ------- + nn.Sequential + Sequential model containing all layers in the block. """ layers = [] @@ -385,8 +356,12 @@ def register_modules(self, module_list: list[nn.Module], name: str) -> None: Used to enable model graph creation with non-sequential model types and dynamic layer numbers - :param list(torch.nn.module) module_list: list of modules to register - :param str name: name of module type + Parameters + ---------- + module_list : list[torch.nn.Module] + List of modules to register. + name : str + Name of module type. """ for i, module in enumerate(module_list): self.add_module(f"{name}_{str(i)}", module) diff --git a/viscy/unet/networks/unext2.py b/viscy/unet/networks/unext2.py index 303bd0a67..e100ddc8c 100644 --- a/viscy/unet/networks/unext2.py +++ b/viscy/unet/networks/unext2.py @@ -71,7 +71,19 @@ def _get_convnext_stage( class UNeXt2Stem(nn.Module): - """Stem for UNeXt2 and ContrastiveEncoder networks.""" + """Stem for UNeXt2 and ContrastiveEncoder networks. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + kernel_size : tuple[int, int, int] + Kernel size. + in_stack_depth : int + Number of input stack depth. + """ def __init__( self, @@ -94,12 +106,12 @@ def forward(self, x: Tensor): Parameters ---------- - x : torch.Tensor + x : Tensor Input tensor of shape (B, C, D, H, W) where D is the stack depth. Returns ------- - torch.Tensor + Tensor Output tensor with depth projected to channels, shape (B, C*D', H', W') where D' = D // kernel_size[0] after 3D convolution. """ @@ -111,7 +123,21 @@ def forward(self, x: Tensor): class StemDepthtoChannels(nn.Module): - """Stem with 3D convolution that maps depth to channels.""" + """Stem with 3D convolution that maps depth to channels. + + Parameters + ---------- + in_channels : int + Number of input channels. + in_stack_depth : int + Number of input stack depth. + in_channels_encoder : int + Number of input channels for the encoder. + stem_kernel_size : tuple[int, int, int] + Kernel size. + stem_stride : tuple[int, int, int] + Stride. 
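+
+    Examples
+    --------
+    Channel arithmetic sketch (the numbers are illustrative assumptions): if
+    ``in_stack_depth=5`` and the depth stride is 5, a single depth slice
+    remains after the 3D convolution, so the stem's output channels must
+    equal ``in_channels_encoder`` for the depth-to-channel reshape to match
+    the encoder input.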
+ """ def __init__( self, @@ -134,7 +160,11 @@ def __init__( ) def compute_stem_channels( - self, in_stack_depth, stem_kernel_size, stem_stride_depth, in_channels_encoder + self, + in_stack_depth: int, + stem_kernel_size: tuple[int, int, int], + stem_stride_depth: int, + in_channels_encoder: int, ): """Compute required 3D stem output channels for encoder compatibility. @@ -199,7 +229,24 @@ class UNeXt2UpStage(nn.Module): low-resolution features with high-resolution skip connections for multi-scale feature fusion. - # TODO: MANUAL_REVIEW - ConvNeXt block integration with skip connections + Parameters + ---------- + in_channels : int + Number of input channels. + skip_channels : int + Number of skip channels. + out_channels : int + Number of output channels. + scale_factor : int + Scale factor. + mode : Literal["deconv", "pixelshuffle"] + Mode. "deconv" for deconvolution, "pixelshuffle" for pixel shuffle. + conv_blocks : int + Number of ConvNeXt blocks. + norm_name : str + Name of the normalization layer. + upsample_pre_conv : Literal["default"] | Callable | None + Upsample pre-convolution. """ def __init__( @@ -283,7 +330,18 @@ class PixelToVoxelHead(nn.Module): convolutions. Applies depth channel expansion and spatial upsampling to generate volumetric outputs from 2D feature representations. - # TODO: MANUAL_REVIEW - 2D to 3D reconstruction mechanism + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + out_stack_depth : int + Number of output stack depth. + expansion_ratio : int + Expansion ratio. + pool : bool + Whether to apply pooling in upsampling. """ def __init__( @@ -375,7 +433,20 @@ class UNeXt2Decoder(nn.Module): combining features from different encoder scales through skip connections. Each stage performs feature upsampling and refinement using ConvNeXt blocks. - # TODO: MANUAL_REVIEW - Multi-scale feature fusion strategy + Parameters + ---------- + num_channels : list[int] + Number of channels for each stage. + norm_name : str + Name of the normalization layer. + mode : Literal["deconv", "pixelshuffle"] + Mode. "deconv" for deconvolution, "pixelshuffle" for pixel shuffle. + conv_blocks : int + Number of ConvNeXt blocks. + strides : list[int] + Strides for each stage. + upsample_pre_conv : Literal["default"] | Callable | None + Upsample pre-convolution. """ def __init__( @@ -434,7 +505,36 @@ class UNeXt2(nn.Module): 2D multi-scale processing through ConvNeXt encoder-decoder, and 2D-to-3D reconstruction via specialized head modules. - # TODO: MANUAL_REVIEW - ConvNeXt transformer integration patterns + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + in_stack_depth : int + Number of input stack depth. + out_stack_depth : int, optional + Number of output stack depth. By default, None, it is the same as the input stack depth. + backbone : str + Backbone model. + pretrained : bool + Whether to use pretrained weights. + stem_kernel_size : tuple[int, int, int] + Kernel size. + decoder_mode : Literal["deconv", "pixelshuffle"] + Mode. "deconv" for deconvolution, "pixelshuffle" for pixel shuffle. + decoder_conv_blocks : int + Number of ConvNeXt blocks. By default, 2. + decoder_norm_layer : str, optional + Name of the normalization layer. By default, "instance". + decoder_upsample_pre_conv : bool, optional + Whether to use upsample pre-convolution. By default, False. + head_pool : bool, optional + Whether to apply pooling in upsampling. 
By default, False.
+    head_expansion_ratio : int, optional
+        Expansion ratio. By default, 4.
+    drop_path_rate : float, optional
+        Drop path rate. By default, 0.0.
     """
 
     def __init__(
@@ -442,7 +542,7 @@ def __init__(
         in_channels: int = 1,
         out_channels: int = 1,
         in_stack_depth: int = 5,
-        out_stack_depth: int = None,
+        out_stack_depth: int | None = None,
         backbone: str = "convnextv2_tiny",
         pretrained: bool = False,
         stem_kernel_size: tuple[int, int, int] = (5, 4, 4),
diff --git a/viscy/utils/aux_utils.py b/viscy/utils/aux_utils.py
index 1a6e37ddc..ea597f4f6 100644
--- a/viscy/utils/aux_utils.py
+++ b/viscy/utils/aux_utils.py
@@ -1,5 +1,7 @@
 """Auxiliary utility functions."""
 
+from pathlib import Path
+
 import iohub.ngff as ngff
 import yaml
 
@@ -44,7 +46,7 @@ def _assert_unique_subset(subset, superset, name):
 
 
 def validate_metadata_indices(
-    zarr_dir,
+    zarr_dir: str | Path,
     time_ids=[],
     channel_ids=[],
     slice_ids=[],
@@ -61,7 +63,7 @@ def validate_metadata_indices(
 
     Parameters
     ----------
-    zarr_dir : str
+    zarr_dir : str | Path
         HCS-compatible zarr directory to validate indices against.
     time_ids : list, optional
         Check availability of these timepoints in image metadata, by default [].
@@ -113,12 +115,12 @@ def validate_metadata_indices(
     return indices_metadata
 
 
-def read_config(config_fname):
+def read_config(config_fname: str | Path):
     """Read the config file in yml format.
 
     Parameters
     ----------
-    config_fname : str
+    config_fname : str | Path
         Filename of config yaml with its full path.
 
     Returns
diff --git a/viscy/utils/cli_utils.py b/viscy/utils/cli_utils.py
index 92dcef04b..a88815172 100644
--- a/viscy/utils/cli_utils.py
+++ b/viscy/utils/cli_utils.py
@@ -3,23 +3,26 @@
 import collections
 import os
 import re
+from pathlib import Path
 
 import numpy as np
 import torch
+from numpy.typing import NDArray
 from PIL import Image
+from torch.utils.data import DataLoader
 
 
-def unique_tags(directory):
-    """Return list of unique nume tags from data directory.
+def unique_tags(directory: str | Path) -> dict[str, int]:
+    """Return a dictionary of unique name tags from the data directory.
 
     Parameters
     ----------
-    directory : str
+    directory : str | Path
         Directory containing '.tif' files.
 
     Returns
     -------
-    dict
+    dict[str, int]
         Dictionary of unique tags and their counts.
 
     Notes
@@ -46,13 +49,18 @@ class MultiProcessProgressBar:
 
     Provides the ability to create & update a single progress bar
     for multi-depth multi-processed tasks by calling updates on
     a single object.
+
+    Parameters
+    ----------
+    total_updates : int
+        Total number of updates.
     """
 
-    def __init__(self, total_updates):
+    def __init__(self, total_updates: int) -> None:
         self.dataloader = list(range(total_updates))
         self.current = 0
 
-    def tick(self, process):
+    def tick(self, process: str) -> None:
         """Update progress bar with current process status.
 
         Parameters
@@ -64,14 +72,16 @@ def tick(self, process):
         show_progress_bar(self.dataloader, self.current, process)
 
 
-def show_progress_bar(dataloader, current, process="training", interval=1):
+def show_progress_bar(
+    dataloader: DataLoader, current: int, process: str = "training", interval: int = 1
+) -> None:
     """Print TensorFlow-like progress bar for batch processing.
 
     Written instead of using tqdm to allow for custom progress bar readouts.
 
     Parameters
     ----------
-    dataloader : iterable
+    dataloader : DataLoader
         Dataloader currently being processed.
     current : int
         Current index in dataloader. 
@@ -105,7 +115,14 @@ def show_progress_bar(dataloader, current, process="training", interval=1):
     print(output_string)
 
 
-def save_figure(data, save_folder, name, title=None, vmax=0, ext=".png"):
+def save_figure(
+    data: NDArray | torch.Tensor,
+    save_folder: str | Path,
+    name: str,
+    title: str | None = None,
+    vmax: float = 0,
+    ext: str = ".png",
+) -> None:
     """Save image data as PNG or JPEG figure.
 
     Saves .png or .jpeg figure of data to folder save_folder under 'name'.
@@ -113,9 +130,9 @@
 
     Parameters
     ----------
-    data : numpy.ndarray or torch.Tensor
+    data : NDArray | torch.Tensor
         Input image/stack data to save in channels_first format.
-    save_folder : str
+    save_folder : str | Path
         Global path to folder where data is saved.
     name : str
         Name of data, no extension specified.
     title : str, optional
         Title of figure, by default None.
     vmax : float, optional
         Value to normalize figure to, by default 0 (uses data max).
     ext : str, optional
         Image save file extension, by default ".png".
+
+    Raises
+    ------
+    AttributeError
+        If data is not a torch tensor or numpy array.
     """
     assert len(data.shape) == 3, f"'{len(data.shape)}d' data must be 3-dimensional"
     if isinstance(data, torch.Tensor):
         data = data.detach().cpu().numpy()
     elif not isinstance(data, np.ndarray):
         raise AttributeError(
             f"'data' of type {type(data)} must be torch tensor or numpy array."
         )
diff --git a/viscy/utils/image_utils.py b/viscy/utils/image_utils.py
index 2462e3e15..9f24631bd 100644
--- a/viscy/utils/image_utils.py
+++ b/viscy/utils/image_utils.py
@@ -19,7 +19,7 @@ def im_bit_convert(
 
     Parameters
     ----------
-    im : array_like
+    im : ArrayLike
         Input image to convert.
     bit : int, optional
         Target bit depth (8 or 16), by default 16.
 
     Returns
     -------
-    np.array
+    NDArray
         Image converted to specified bit depth.
     """
     im = im.astype(
@@ -61,7 +61,7 @@ def im_adjust(img: ArrayLike, tol: int | float = 1, bit: int = 8) -> NDArray[Any]:
 
     Parameters
     ----------
-    img : array_like
+    img : ArrayLike
         Input image to adjust.
     tol : int or float, optional
         Tolerance percentile for contrast stretching, by default 1.
     bit : int, optional
         Target bit depth (8 or 16), by default 8.
 
     Returns
     -------
-    np.array
+    NDArray
         Contrast-adjusted image in specified bit depth.
     """
     limit = np.percentile(img, [tol, 100 - tol])
@@ -87,18 +87,18 @@ def grid_sample_pixel_values(
 
     Parameters
     ----------
-    im : np.array
+    im : NDArray
         2D image to sample from.
     grid_spacing : int
         Spacing of the grid points.
 
     Returns
     -------
-    row_ids : np.array
+    row_ids : NDArray
         Row indices of the grid points.
-    col_ids : np.array
+    col_ids : NDArray
         Column indices of the grid points.
-    sample_values : np.array
+    sample_values : NDArray
         Sampled pixel values at grid points.
     """
     im_shape = im.shape
@@ -134,9 +134,9 @@ def preprocess_image(
 
     Parameters
    ----------
-    im : np.array
+    im : ArrayLike
         Input image or image stack.
-    hist_clip_limits : tuple, optional
+    hist_clip_limits : tuple[float, float], optional
         Percentile histogram clipping limits (min_percentile, max_percentile),
         by default None.
     is_mask : bool, optional
 
     Returns
     -------
-    np.array
+    NDArray
         Preprocessed image. 
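+
+    Examples
+    --------
+    Minimal sketch (the random volume is an illustrative stand-in for real
+    image data):
+
+    >>> import numpy as np
+    >>> im = np.random.rand(1, 64, 64)
+    >>> out = preprocess_image(im, hist_clip_limits=(1, 99))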
""" # remove singular dimension for 3D images diff --git a/viscy/utils/log_images.py b/viscy/utils/log_images.py index 77bee649d..3d26a5b7b 100644 --- a/viscy/utils/log_images.py +++ b/viscy/utils/log_images.py @@ -4,13 +4,14 @@ import numpy as np from matplotlib.pyplot import get_cmap +from numpy.typing import NDArray from skimage.exposure import rescale_intensity from torch import Tensor def detach_sample( imgs: Sequence[Tensor], log_samples_per_batch: int -) -> list[list[np.ndarray]]: +) -> list[list[NDArray]]: """Detach example images from the batch and convert them to numpy arrays. Parameters @@ -22,7 +23,7 @@ def detach_sample( Returns ------- - list[list[np.ndarray]] + list[list[NDArray]] Grid of example images. Rows are samples, columns are channels. """ @@ -38,21 +39,19 @@ def detach_sample( return samples -def render_images( - imgs: Sequence[Sequence[np.ndarray]], cmaps: list[str] = [] -) -> np.ndarray: +def render_images(imgs: Sequence[Sequence[NDArray]], cmaps: list[str] = []) -> NDArray: """Render images in a grid. Parameters ---------- - imgs : Sequence[Sequence[np.ndarray]] + imgs : Sequence[Sequence[NDArray]] Grid of images to render, output of `detach_sample`. cmaps : list[str], optional Colormaps for each column, by default [] Returns ------- - np.ndarray + NDArray Rendered RGB images grid. """ images_grid = [] diff --git a/viscy/utils/logging.py b/viscy/utils/logging.py index 3aa046863..e00821d63 100644 --- a/viscy/utils/logging.py +++ b/viscy/utils/logging.py @@ -12,19 +12,25 @@ def log_feature( feature_map: torch.Tensor, name: str, log_save_folder: str, debug_mode: bool ) -> None: - """ - Create visual feature map logs for debugging deep learning models. + """Create visual feature map logs for debugging deep learning models. If debug_mode is enabled, creates a visual of the given feature map and saves it at - 'log_save_folder'. If no log_save_folder specified, saves relative to working directory with timestamp. + 'log_save_folder'. If no log_save_folder specified, saves relative to working + directory with timestamp. Currently only saving in working directory is supported. - This is meant to be an analysis tool, - and results should not be saved permanently. + This is meant to be an analysis tool, and results should not be saved permanently. - :param torch.tensor feature_map: feature map to create visualization log of - :param str name: string - :param str log_save_folder + Parameters + ---------- + feature_map : torch.Tensor + Feature map to create visualization log of. + name : str + Name identifier for the feature map visualization. + log_save_folder : str + Directory path for saving the visualization output. + debug_mode : bool + Whether to enable debug mode visualization logging. """ try: if debug_mode: @@ -100,25 +106,28 @@ def __init__( grid_width: int = 0, normalize_by_grid: bool = False, ) -> None: - """ - Logger object for handling logging feature maps inside network architectures. - - Saves each 2d slice of a feature map in either a single grid per feature map - stack or a directory tree of labeled slices. - - By default saves images into grid. 
- - :param str save_folder: output directory - :param bool full_batch: if true, log all sample in batch (warning slow!), - defaults to False - :param bool save_as_grid: if true feature maps are to be saved as a grid - containing all channels, else saved individually, - defaults to True - :param int grid_width: desired width of grid if save_as_grid, by default - 1/4 the number of channels, defaults to 0 - :param bool normalize_by_grid: if true, images saved in grid are normalized - to brightest pixel in entire grid, defaults to False - + """Initialize logger for handling feature map visualization in neural networks. + + Saves each 2D slice of a feature map in either a single grid per feature map + stack or a directory tree of labeled slices. By default saves images into grid. + + Parameters + ---------- + save_folder : str + Output directory for saving visualization files. + spatial_dims : int, optional + Number of spatial dimensions in feature tensors, by default 3. + full_batch : bool, optional + If true, log all samples in batch (warning: slow!), by default False. + save_as_grid : bool, optional + If true, feature maps are saved as a grid containing all channels, + else saved individually, by default True. + grid_width : int, optional + Desired width of grid if save_as_grid. If 0, defaults to 1/4 the + number of channels, by default 0. + normalize_by_grid : bool, optional + If true, images saved in grid are normalized to brightest pixel in + entire grid, by default False. """ self.save_folder = save_folder self.spatial_dims = spatial_dims @@ -136,22 +145,25 @@ def log_feature_map( dim_names: list[str] | None = None, vmax: float = 0, ) -> None: - """ - Create a log of figures for the given feature map tensor at 'save_folder'. - - Log is saved as images of feature maps in nested directory tree. - - By default _assumes that batch dimension is the first dimension_, and - only logs the first sample in the batch, for performance reasons. - - Feature map logs cannot overwrite. - - :param torch.Tensor feature_map: feature map to log (typically 5d tensor) - :parapm str feature_name: name of feature (will be used as dir name) - :param list dim_names: names of each dimension, by default just numbers - :param int spatial_dims: number of spatial dims, defaults to 3 - :param float vmax: maximum intensity to normalize figures by, by default - (if given 0) does relative normalization + """Create a log of figures for the given feature map tensor. + + Log is saved as images of feature maps in nested directory tree at save_folder. + + By default assumes that batch dimension is the first dimension, and only logs + the first sample in the batch for performance reasons. Feature map logs cannot + overwrite existing files. + + Parameters + ---------- + feature_map : torch.Tensor + Feature map to log, typically 5D tensor (BCDHW or BCTHW). + feature_name : str + Name of feature, used as directory name for organizing outputs. + dim_names : list[str] | None, optional + Names of each non-spatial dimension, by default just numbers. + vmax : float, optional + Maximum intensity to normalize figures by. If 0, uses relative + normalization, by default 0. 
""" # take tensor off of gpu and detach gradient feature_map = feature_map.detach().cpu() @@ -159,7 +171,7 @@ def log_feature_map( # handle dim names num_dims = len(feature_map.shape) if dim_names is None: - dim_names = ["dim_" + str(i) for i in range(len(num_dims))] + dim_names = ["dim_" + str(i) for i in range(num_dims)] else: assert len(dim_names) + self.spatial_dims == num_dims, ( "dim_names must be same length as nonspatial tensor dim length" @@ -183,17 +195,26 @@ def map_feature_dims( vmax: float = 0, depth: int = 0, ) -> None: - """ - Recursive directory creation for organizing feature map logs - - If save_as_grid, will compile 'channels' (assumed to be last - non-spatial dimension) into a single large image grid before saving. - - :param numpy.ndarray feature_map: see name - :param str save_dir: see name - :param bool save_as_grid: if true, saves images as channel grid - :param float vmax: maximum intensity to normalize figures by - :param int depth: recursion counter. depth in dimensions + """Recursively create directory structure for organizing feature map logs. + + If save_as_grid is True, compiles 'channels' (assumed to be last non-spatial + dimension) into a single large image grid before saving. + + Parameters + ---------- + feature_map : torch.Tensor + Feature tensor to process and save. + save_as_grid : bool + If true, saves images as channel grid layout. + vmax : float, optional + Maximum intensity to normalize figures by, by default 0. + depth : int, optional + Recursion counter tracking depth in tensor dimensions, by default 0. + + Raises + ------ + AttributeError + If the feature map has an invalid number of dimensions. """ for i in range(feature_map.shape[0]): if len(feature_map.shape) == 3: @@ -309,16 +330,26 @@ def interleave_bars( pixel_width: int = 3, value: float = 0, ) -> list[torch.Tensor]: - """ - Interleave separator bars between tensors to improve grid visualization. - - Takes list of 2d torch tensors and interleaves bars to improve - grid visualization quality. Assumes arrays are all of the same shape. - - :param list grid_arrays: list of tensors to place bars between - :param int axis: axis on which to interleave bars (0 or 1) - :param int pixel_width: width of bar, defaults to 3 - :param int value: value of bar pixels, defaults to 0 + """Interleave separator bars between tensors to improve grid visualization. + + Takes list of 2D torch tensors and interleaves bars to improve grid + visualization quality. Assumes arrays are all of the same shape. + + Parameters + ---------- + arrays : list[torch.Tensor] + List of tensors to place separator bars between. + axis : int + Axis on which to interleave bars (0 or 1). + pixel_width : int, optional + Width of separator bar in pixels, by default 3. + value : float, optional + Pixel value for separator bars, by default 0. + + Returns + ------- + list[torch.Tensor] + List of tensors with separator bars interleaved for grid visualization. """ shape_match_axis = abs(axis - 1) length = arrays[0].shape[shape_match_axis] diff --git a/viscy/utils/masks.py b/viscy/utils/masks.py index 7366fb53e..f000fce94 100644 --- a/viscy/utils/masks.py +++ b/viscy/utils/masks.py @@ -21,7 +21,7 @@ def create_otsu_mask( Parameters ---------- - input_image : np.array + input_image : NDArray Generate masks from this 3D image. sigma : float, optional Gaussian blur standard deviation, increase in value increases blur, @@ -29,7 +29,7 @@ def create_otsu_mask( Returns ------- - np.array + NDArray Volume mask of input_image, 3D binary array. 
""" input_sz = input_image.shape @@ -52,7 +52,7 @@ def create_membrane_mask( Parameters ---------- - input_image : np.array + input_image : NDArray Generate masks from this image. str_elem_size : int, optional Size of the laplacian filter used for contrast enhancement, odd number. @@ -68,7 +68,7 @@ def create_membrane_mask( Returns ------- - np.array + NDArray Binary mask of input_image. """ input_image_blur = gaussian(input_image, sigma=sigma) @@ -100,7 +100,7 @@ def get_unimodal_threshold(input_image: NDArray[Any]) -> float: Parameters ---------- - input_image : np.array + input_image : NDArray Generate mask for this image. Returns @@ -150,7 +150,7 @@ def create_unimodal_mask( Parameters ---------- - input_image : np.array + input_image : NDArray Generate masks from this image. str_elem_size : int, optional Size of the structuring element, typically 3 or 5, by default 3. @@ -159,7 +159,7 @@ def create_unimodal_mask( Returns ------- - np.array + NDArray Binary mask of input_image. """ input_image = gaussian(input_image, sigma=sigma) @@ -189,7 +189,7 @@ def get_unet_border_weight_map( Parameters ---------- - annotation : np.array + annotation : NDArray A 2D array of shape (image_height, image_width) containing annotation with each class labeled as an integer. w0 : int, optional @@ -202,7 +202,7 @@ def get_unet_border_weight_map( Returns ------- - np.array + NDArray Weight map for borders as specified in U-Net paper. """ # if there is only one label, zero return the array as is diff --git a/viscy/utils/meta_utils.py b/viscy/utils/meta_utils.py index 22a28754e..9cb753f78 100644 --- a/viscy/utils/meta_utils.py +++ b/viscy/utils/meta_utils.py @@ -1,33 +1,39 @@ import os import sys +from pathlib import Path import iohub.ngff as ngff import numpy as np import pandas as pd +from numpy.typing import NDArray import viscy.utils.mp_utils as mp_utils from viscy.utils.cli_utils import show_progress_bar def write_meta_field(position: ngff.Position, metadata, field_name, subfield_name): - """Write 'metadata' to position's plate-level or FOV level .zattrs metadata. + """Write metadata to position's plate-level or FOV level .zattrs metadata. - Write 'metadata' to position's plate-level or FOV level .zattrs metadata by either - creating a new field (field_name) according to 'metadata', or updating the metadata - to an existing field if found, - or concatenating the metadata from different channels. + Write metadata to position's plate-level or FOV level .zattrs metadata by either + creating a new field (field_name) according to metadata, or updating the metadata + to an existing field if found, or concatenating the metadata from different channels. - Assumes that the zarr store group given follows the OMG-NGFF HCS - format as specified here: - https://ngff.openmicroscopy.org/latest/#hcs-layout + Assumes that the zarr store group given follows the OME-NGFF HCS + format as specified here: https://ngff.openmicroscopy.org/latest/#hcs-layout Warning: Dangerous. Writing metadata fields above the image-level of - an HCS hierarchy can break HCS compatibility - - :param Position zarr_dir: NGFF position node object - :param dict metadata: metadata dictionary to write to JSON .zattrs - :param str subfield_name: name of subfield inside the the main field - (values for different channels) + an HCS hierarchy can break HCS compatibility. + + Parameters + ---------- + position : ngff.Position + NGFF position node object. + metadata : dict + Metadata dictionary to write to JSON .zattrs. 
+ field_name : str + Name of the main metadata field. + subfield_name : str + Name of subfield inside the main field (values for different channels). """ if field_name in position.zattrs: if subfield_name in position.zattrs[field_name]: @@ -47,10 +53,10 @@ def write_meta_field(position: ngff.Position, metadata, field_name, subfield_nam def generate_normalization_metadata( - zarr_dir, - num_workers=4, - channel_ids=-1, - grid_spacing=32, + zarr_dir: str | Path, + num_workers: int = 4, + channel_ids: list[int] | int = -1, + grid_spacing: int = 32, ): """Generate pixel intensity metadata for on-the-fly normalization. @@ -65,17 +71,22 @@ def generate_normalization_metadata( channel_idx : { dataset_statistics: dataset level normalization values (positive float), fov_statistics: field-of-view level normalization values (positive float) - }, - . - . - . + } } - :param str zarr_dir: path to zarr store directory containing dataset. - :param int num_workers: number of cpu workers for multiprocessing, defaults to 4 - :param list/int channel_ids: indices of channels to process in dataset arrays, - by default calculates all - :param int grid_spacing: distance between points in sampling grid + Warning: Dangerous. Writing metadata fields above the image-level of + an HCS hierarchy can break HCS compatibility. + + Parameters + ---------- + zarr_dir : str + Path to zarr store directory containing dataset. + num_workers : int, optional + Number of CPU workers for multiprocessing, by default 4. + channel_ids : list[int] | int, optional + Indices of channels to process in dataset arrays, by default -1 (all channels). + grid_spacing : int, optional + Distance between points in sampling grid, by default 32. """ plate = ngff.open_ome_zarr(zarr_dir, mode="r+") position_map = list(plate.positions()) @@ -172,12 +183,22 @@ def generate_normalization_metadata( print(f"Dataset-level statistics: {final_dataset_stats}") -def compute_normalization_stats(image_data, grid_spacing=32): +def compute_normalization_stats( + image_data: NDArray, grid_spacing: int = 32 +) -> dict[str, float]: """Compute normalization statistics from image data using grid sampling. - :param np.array image_data: 3D or 4D image array (z, y, x) or (t, z, y, x) - :param int grid_spacing: spacing between grid points for sampling - :return dict: dictionary with median and IQR statistics + Parameters + ---------- + image_data : np.ndarray + 3D or 4D image array of shape (z, y, x) or (t, z, y, x). + grid_spacing : int, optional + Spacing between grid points for sampling, by default 32. + + Returns + ------- + dict[str, float] + Dictionary with median and IQR statistics for normalization. """ # Handle different input shapes if image_data.ndim == 4: diff --git a/viscy/utils/mp_utils.py b/viscy/utils/mp_utils.py index ce04e8093..686dade41 100644 --- a/viscy/utils/mp_utils.py +++ b/viscy/utils/mp_utils.py @@ -6,6 +6,7 @@ import numpy as np import scipy.stats import zarr +from numpy.typing import NDArray import viscy.utils.image_utils as image_utils import viscy.utils.masks as mask_utils @@ -61,7 +62,7 @@ def mp_create_and_write_mask(fn_args: list[tuple[Any, ...]], workers: int) -> li def add_channel( position: ngff.Position, - new_channel_array: np.ndarray, + new_channel_array: NDArray, new_channel_name: str, overwrite_ok: bool = False, ) -> None: @@ -82,7 +83,7 @@ def add_channel( ---------- position : ngff.Position NGFF position node object. 
- new_channel_array : np.ndarray + new_channel_array : NDArray Array to add as new channel with matching dimensions (except channel dim) and dtype. new_channel_name : str @@ -229,7 +230,7 @@ def get_mask_slice( channel_index: int, mask_type: str, structure_elem_radius: int, -) -> np.ndarray: +) -> NDArray: """Compute mask for a single image slice. Given a set of indices, mask type, and structuring element, @@ -251,7 +252,7 @@ def get_mask_slice( Returns ------- - np.ndarray + NDArray 2D mask for this slice. """ # read and correct/preprocess slice @@ -280,9 +281,17 @@ def mp_get_val_stats(fn_args: list[Any], workers: int) -> list[dict[str, float]] """ Compute statistics of numpy arrays with multiprocessing - :param list of tuple fn_args: list with tuples of function arguments - :param int workers: max number of workers - :return: list of returned df from get_im_stats + Parameters + ---------- + fn_args : list of tuple + List with tuples of function arguments. + workers : int + Max number of workers. + + Returns + ------- + list[dict[str, float]] + List of returned df from get_im_stats. """ with ProcessPoolExecutor(workers) as ex: # can't use map directly as it works only with single arg functions @@ -303,7 +312,7 @@ def get_val_stats(sample_values: list[float]) -> dict[str, float]: Returns ------- - dict + dict[str, float] Dictionary with intensity data for image. """ meta_row = { @@ -320,9 +329,17 @@ def mp_sample_im_pixels( ) -> list[list[Any]]: """Read and computes statistics of images with multiprocessing - :param list of tuple fn_args: list with tuples of function arguments - :param int workers: max number of workers - :return: list of paths and corresponding returned df from get_im_stats + Parameters + ---------- + fn_args : list[tuple[Any, ...]] + List with tuples of function arguments. + workers : int + Max number of workers. + + Returns + ------- + list[list[Any]] + List of paths and corresponding returned df from get_im_stats. """ with ProcessPoolExecutor(workers) as ex: # can't use map directly as it works only with single arg functions @@ -334,7 +351,7 @@ def sample_im_pixels( position: ngff.Position, grid_spacing: int, channel: int, -) -> tuple[ngff.Position, np.ndarray]: +) -> tuple[ngff.Position, NDArray]: # TODO move out of mp utils into normalization utils """Read and compute statistics of images for each point in a grid. @@ -354,8 +371,8 @@ def sample_im_pixels( Returns ------- - list - Dicts with intensity data for each grid point. + tuple[ngff.Position, NDArray] + Position and array with intensity data for each grid point. """ image_zarr = position.data diff --git a/viscy/utils/normalize.py b/viscy/utils/normalize.py index 6d12baba3..e86eae818 100644 --- a/viscy/utils/normalize.py +++ b/viscy/utils/normalize.py @@ -9,7 +9,7 @@ def zscore( - input_image: ArrayLike, im_mean: float | None = None, im_std: float | None = None + input_image: NDArray, im_mean: float | None = None, im_std: float | None = None ) -> NDArray[Any]: """Perform z-score normalization. @@ -17,7 +17,7 @@ def zscore( Parameters ---------- - input_image : np.array + input_image : NDArray Input image for intensity normalization. im_mean : float, optional Image mean, by default None. @@ -26,7 +26,7 @@ def zscore( Returns ------- - np.array + NDArray Z-score normalized image. 
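+
+    Examples
+    --------
+    Minimal sketch (the values are illustrative):
+
+    >>> import numpy as np
+    >>> im = np.array([1.0, 2.0, 3.0])
+    >>> normalized = zscore(im)  # approximately zero mean, unit variance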
""" if not im_mean: @@ -37,16 +37,14 @@ def zscore( return norm_img -def unzscore( - im_norm: ArrayLike, zscore_median: float, zscore_iqr: float -) -> NDArray[Any]: +def unzscore(im_norm: NDArray, zscore_median: float, zscore_iqr: float) -> NDArray[Any]: """Revert z-score normalization applied during preprocessing. Necessary before computing SSIM. Parameters ---------- - im_norm : array_like + im_norm : NDArray Normalized image for un-zscore. zscore_median : float Image median. @@ -55,7 +53,7 @@ def unzscore( Returns ------- - array_like + NDArray Image at its original scale. """ im = im_norm * (zscore_iqr + sys.float_info.epsilon) + zscore_median @@ -63,7 +61,7 @@ def unzscore( def hist_clipping( - input_image: ArrayLike, + input_image: NDArray, min_percentile: int | float = 2, max_percentile: int | float = 98, ) -> NDArray[Any]: @@ -73,7 +71,7 @@ def hist_clipping( Parameters ---------- - input_image : np.array + input_image : NDArray Input image for intensity normalization. min_percentile : int or float, optional Min intensity percentile, by default 2. @@ -82,7 +80,7 @@ def hist_clipping( Returns ------- - np.array + NDArray Intensity clipped and rescaled image. """ assert (min_percentile < max_percentile) and max_percentile <= 100 @@ -92,10 +90,10 @@ def hist_clipping( def hist_adapteq_2D( - input_image: NDArray[Any], + input_image: NDArray, kernel_size: int | list[int] | tuple[int, ...] | None = None, clip_limit: float | None = None, -) -> NDArray[Any]: +) -> NDArray: """Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) on 2D images. skimage.exposure.equalize_adapthist works only for 2D. Extend to 3D or use @@ -103,7 +101,7 @@ def hist_adapteq_2D( Parameters ---------- - input_image : np.array + input_image : NDArray Input image for intensity normalization. kernel_size : int or list, optional Neighbourhood to be used for histogram equalization. If None, use default @@ -116,7 +114,7 @@ def hist_adapteq_2D( Returns ------- - np.array + NDArray Adaptive histogram equalized image. 
""" nrows, ncols = input_image.shape diff --git a/viscy/utils/slurm_utils.py b/viscy/utils/slurm_utils.py index 9cfafb84b..c943b8b11 100644 --- a/viscy/utils/slurm_utils.py +++ b/viscy/utils/slurm_utils.py @@ -36,7 +36,8 @@ def calculate_dataloader_settings( Returns ------- - dict: Recommended settings for DataLoader + dict: + Dictionary with recommended settings for DataLoader """ # Get system resources if not provided if available_ram_gb is None: From 24a540f917686322f52d23c19b4f12536125a8a6 Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Thu, 11 Sep 2025 10:16:23 -0700 Subject: [PATCH 03/13] updated __init__ for classes --- ruff.toml | 4 +- viscy/cli.py | 2 +- viscy/data/__init__.py | 1 + viscy/data/distributed.py | 2 +- viscy/data/gpu_aug.py | 2 + viscy/data/hcs.py | 7 +- viscy/data/livecell.py | 2 + viscy/data/mmap_cache.py | 2 + viscy/data/select.py | 2 + viscy/data/triplet.py | 188 +++++++++--------- viscy/preprocessing/generate_masks.py | 43 ++-- viscy/preprocessing/pixel_ratio.py | 4 +- viscy/preprocessing/precompute.py | 2 +- viscy/representation/classification.py | 34 ++-- viscy/representation/contrastive.py | 2 + viscy/representation/embedding_writer.py | 2 + viscy/representation/engine.py | 2 + viscy/representation/evaluation/distance.py | 2 + viscy/representation/evaluation/feature.py | 2 + .../evaluation/visualization.py | 2 + viscy/representation/multi_modal.py | 2 + viscy/trainer.py | 2 + viscy/transforms/__init__.py | 2 + viscy/transforms/_redef.py | 60 ++++++ viscy/translation/engine.py | 8 +- viscy/unet/__init__.py | 1 + viscy/unet/networks/Unet25D.py | 79 ++++---- viscy/unet/networks/Unet2D.py | 65 +++--- viscy/unet/networks/__init__.py | 1 + viscy/unet/networks/layers/ConvBlock3D.py | 8 +- viscy/unet/networks/layers/__init__.py | 1 + viscy/unet/networks/unext2.py | 6 +- viscy/utils/__init__.py | 2 +- viscy/utils/cli_utils.py | 2 +- viscy/utils/logging.py | 58 +++--- viscy/utils/masks.py | 2 + viscy/utils/meta_utils.py | 2 + viscy/utils/mp_utils.py | 7 +- viscy/utils/slurm_utils.py | 2 + 39 files changed, 360 insertions(+), 257 deletions(-) diff --git a/ruff.toml b/ruff.toml index c3e0dbc8a..7e3f1e007 100644 --- a/ruff.toml +++ b/ruff.toml @@ -23,8 +23,8 @@ select = [ "I", # isort ] ignore = [ - "D100", # Missing docstring in public module - "D104", # Missing docstring in public package + # "D100", # Missing docstring in public module + # "D104", # Missing docstring in public package "D105", # __magic__ methods are often self-explanatory, allow missing docstrings "D107", # Missing docstring in __init__ # Disable one in each pair of mutually incompatible rules diff --git a/viscy/cli.py b/viscy/cli.py index 65f4137ed..ad6537e98 100644 --- a/viscy/cli.py +++ b/viscy/cli.py @@ -15,7 +15,7 @@ class VisCyCLI(LightningCLI): - """Extending lightning CLI arguments and defualts.""" + """Extending Lightning CLI arguments and defaults for VisCy.""" @staticmethod def subcommands() -> dict[str, set[str]]: diff --git a/viscy/data/__init__.py b/viscy/data/__init__.py index e69de29bb..f8b10be0b 100644 --- a/viscy/data/__init__.py +++ b/viscy/data/__init__.py @@ -0,0 +1 @@ +"""VisCy data loading and preprocessing modules.""" diff --git a/viscy/data/distributed.py b/viscy/data/distributed.py index badce1ab3..8da588fea 100644 --- a/viscy/data/distributed.py +++ b/viscy/data/distributed.py @@ -34,7 +34,7 @@ def _sharded_randperm(self, max_size: int, generator: Generator) -> list[int]: return indices.tolist() def __iter__(self): - """Modified __iter__ method to shard data across 
distributed ranks.""" + """Iterate through sharded data across distributed ranks.""" max_size = len(self.dataset) # type: ignore[arg-type] if self.shuffle: # deterministically shuffle based on epoch and seed diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py index abca552e5..f5e8ecb4e 100644 --- a/viscy/data/gpu_aug.py +++ b/viscy/data/gpu_aug.py @@ -1,3 +1,5 @@ +"""GPU-accelerated data augmentation modules for microscopy ML training.""" + from __future__ import annotations from abc import ABC, abstractmethod diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index c5501c385..4cec0033b 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -237,6 +237,7 @@ def _read_img_window( return torch.from_numpy(data).unbind(dim=1), HCSStackIndex(img.name, t, z) def __len__(self) -> int: + """Return total number of sliding windows across all FOVs.""" return self._max_window # TODO: refactor to a top level function @@ -255,6 +256,7 @@ def _stack_channels( ] def __getitem__(self, index: int) -> Sample: + """Get sliding window sample by index.""" img, tz, norm_meta = self._find_window(index) ch_names = self.channels["source"].copy() ch_idx = self.source_ch_idx.copy() @@ -330,6 +332,7 @@ def __init__( _logger.info(str(self.masks)) def __getitem__(self, index: int) -> Sample: + """Get sample with ground truth mask if available.""" sample = super().__getitem__(index) img_name, t_idx, z_idx = sample["index"] position_name = int(img_name.split("/")[-2]) @@ -623,7 +626,7 @@ def _setup_predict( ) def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample: - """Removes redundant Z slices if the target is 2D to save VRAM.""" + """Remove redundant Z slices if the target is 2D to save VRAM.""" predicting = False if self.trainer: if self.trainer.predicting: @@ -738,7 +741,7 @@ def _fit_transform(self) -> tuple[Compose, Compose]: return train_transform, val_transform def _train_transform(self) -> list[Callable]: - """Setup training augmentations. + """Set up training augmentations. Check input values and parse the number of Z slices and patches to sample per stack. diff --git a/viscy/data/livecell.py b/viscy/data/livecell.py index d7134fe87..50e1c201d 100644 --- a/viscy/data/livecell.py +++ b/viscy/data/livecell.py @@ -1,3 +1,5 @@ +"""LiveCell dataset implementation for cell segmentation benchmarking.""" + from __future__ import annotations import json diff --git a/viscy/data/mmap_cache.py b/viscy/data/mmap_cache.py index b3cf427f4..10e967c10 100644 --- a/viscy/data/mmap_cache.py +++ b/viscy/data/mmap_cache.py @@ -1,3 +1,5 @@ +"""Memory-mapped caching for OME-Zarr data with efficient disk I/O.""" + from __future__ import annotations import os diff --git a/viscy/data/select.py b/viscy/data/select.py index 4a0e4539b..509ac2440 100644 --- a/viscy/data/select.py +++ b/viscy/data/select.py @@ -1,3 +1,5 @@ +"""Well and field-of-view selection utilities for HCS datasets.""" + from collections.abc import Generator from iohub.ngff.nodes import Plate, Position, Well diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index 1392f83f7..3fbdb3be6 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -1,3 +1,5 @@ +"""Triplet sampling for contrastive learning on tracked cell data.""" + import logging from collections.abc import Sequence from pathlib import Path @@ -72,6 +74,43 @@ class TripletDataset(Dataset): Generates anchor, positive, and negative triplets from tracked cell patches for contrastive learning. Supports temporal sampling with configurable time intervals. 
+ + Parameters + ---------- + positions : list[Position] + OME-Zarr images with consistent channel order + tracks_tables : list[pd.DataFrame] + Data frames containing ultrack results + channel_names : list[str] + Input channel names + initial_yx_patch_size : tuple[int, int] + YX size of the initially sampled image patch before augmentation + z_range : slice + Range of Z-slices + anchor_transform : DictTransform | None, optional + Transforms applied to the anchor sample, by default None + positive_transform : DictTransform | None, optional + Transforms applied to the positive sample, by default None + negative_transform : DictTransform | None, optional + Transforms applied to the negative sample, by default None + fit : bool, optional + Fitting mode in which the full triplet will be sampled, + only sample anchor if ``False``, by default True + predict_cells : bool, optional + Only predict on selected cells, by default False + include_fov_names : list[str] | None, optional + Only predict on selected FOVs, by default None + include_track_ids : list[int] | None, optional + Only predict on selected track IDs, by default None + time_interval : Literal["any"] | int, optional + Future time interval to sample positive and anchor from, + by default "any" + (sample negative from another track at any time point + and use the augmented anchor patch as positive) + return_negative : bool, optional + Whether to return the negative sample during the fit stage + (can be set to False when using a loss function like NT-Xent), + by default True """ def __init__( self, @@ -91,45 +130,6 @@ def __init__( time_interval: Literal["any"] | int = "any", return_negative: bool = True, ) -> None: - """Dataset for triplet sampling of cells based on tracking. - - Parameters - ---------- - positions : list[Position] - OME-Zarr images with consistent channel order - tracks_tables : list[pd.DataFrame] - Data frames containing ultrack results - channel_names : list[str] - Input channel names - initial_yx_patch_size : tuple[int, int] - YX size of the initially sampled image patch before augmentation - z_range : slice - Range of Z-slices - anchor_transform : DictTransform | None, optional - Transforms applied to the anchor sample, by default None - positive_transform : DictTransform | None, optional - Transforms applied to the positve sample, by default None - negative_transform : DictTransform | None, optional - Transforms applied to the negative sample, by default None - fit : bool, optional - Fitting mode in which the full triplet will be sampled, - only sample anchor if ``False``, by default True - predict_cells : bool, optional - Only predict on selected cells, by default False - include_fov_names : list[str] | None, optional - Only predict on selected FOVs, by default None - include_track_ids : list[int] | None, optional - Only predict on selected track IDs, by default None - time_interval : Literal["any"] | int, optional - Future time interval to sample positive and anchor from, - by default "any" - (sample negative from another track any time point - and use the augmented anchor patch as positive) - return_negative : bool, optional - Whether to return the negative sample during the fit stage - (can be set to False when using a loss function like NT-Xent), - by default True - """ self.positions = positions self.channel_names = self.channel_indices = [ @@ -212,6 +212,7 @@ def _specific_cells(self, tracks: pd.DataFrame) -> pd.DataFrame: return specific_tracks.reset_index(drop=True) def __len__(self) -> int: + """Return
number of valid anchor samples.""" return len(self.valid_anchors) def _sample_positives(self, anchor_rows: pd.DataFrame) -> pd.DataFrame: @@ -288,6 +289,7 @@ def _slice_patches(self, track_rows: pd.DataFrame): return torch.from_numpy(np.stack(results, axis=0)), norms def __getitems__(self, indices: list[int]) -> list[TripletSample]: + """Get batched triplet samples for efficient data loading.""" anchor_rows = self.valid_anchors.iloc[indices] anchor_patches, anchor_norms = self._slice_patches(anchor_rows) if self.fit: @@ -350,6 +352,59 @@ class TripletDataModule(HCSDataModule): Provides train, validation, and prediction dataloaders for contrastive learning on cell tracking data. Supports configurable time intervals and spatial patch sampling. + + Parameters + ---------- + data_path : str | Path + Image dataset path + tracks_path : str | Path + Tracks labels dataset path + source_channel : str | Sequence[str] + List of input channel names + z_range : tuple[int, int] + Range of valid z-slices + initial_yx_patch_size : tuple[int, int], optional + XY size of the initially sampled image patch, by default (512, 512) + final_yx_patch_size : tuple[int, int], optional + Output patch size, by default (224, 224) + split_ratio : float, optional + Ratio of training samples, by default 0.8 + batch_size : int, optional + Batch size, by default 16 + num_workers : int, optional + Number of data-loading workers, by default 8 + normalizations : list[MapTransform], optional + Normalization transforms, by default [] + augmentations : list[MapTransform], optional + Augmentation transforms, by default [] + caching : bool, optional + Whether to cache the dataset, by default False + fit_include_wells : list[str], optional + Only include these wells for fitting, by default None + fit_exclude_fovs : list[str], optional + Exclude these FOVs for fitting, by default None + predict_cells : bool, optional + Only predict for selected cells, by default False + include_fov_names : list[str] | None, optional + Only predict for selected FOVs, by default None + include_track_ids : list[int] | None, optional + Only predict for selected tracks, by default None + time_interval : Literal["any"] | int, optional + Future time interval to sample positive and anchor from, + "any" means sampling the negative from another track at any time point + and using the augmented anchor patch as positive, by default "any" + return_negative : bool, optional + Whether to return the negative sample during the fit stage + (can be set to False when using a loss function like NT-Xent), + by default True + persistent_workers : bool, optional + Whether to keep worker processes alive between iterations, by default False + prefetch_factor : int | None, optional + Number of batches loaded in advance by each worker, by default None + pin_memory : bool, optional + Whether to pin memory in CPU for faster GPU transfer, by default False + z_window_size : int, optional + Size of the final Z window, by default None (inferred from z_range) """ def __init__( self, @@ -378,61 +433,6 @@ def __init__( pin_memory: bool = False, z_window_size: int | None = None, ): - """Lightning data module for triplet sampling of patches.
- - Parameters - ---------- - data_path : str | Path - Image dataset path - tracks_path : str | Path - Tracks labels dataset path - source_channel : str | Sequence[str] - List of input channel names - z_range : tuple[int, int] - Range of valid z-slices - initial_yx_patch_size : tuple[int, int], optional - XY size of the initially sampled image patch, by default (512, 512) - final_yx_patch_size : tuple[int, int], optional - Output patch size, by default (224, 224) - split_ratio : float, optional - Ratio of training samples, by default 0.8 - batch_size : int, optional - Batch size, by default 16 - num_workers : int, optional - Number of data-loading workers, by default 8 - normalizations : list[MapTransform], optional - Normalization transforms, by default [] - augmentations : list[MapTransform], optional - Augmentation transforms, by default [] - caching : bool, optional - Whether to cache the dataset, by default False - fit_include_wells : list[str], optional - Only include these wells for fitting, by default None - fit_exclude_fovs : list[str], optional - Exclude these FOVs for fitting, by default None - predict_cells : bool, optional - Only predict for selected cells, by default False - include_fov_names : list[str] | None, optional - Only predict for selected FOVs, by default None - include_track_ids : list[int] | None, optional - Only predict for selected tracks, by default None - time_interval : Literal["any"] | int, optional - Future time interval to sample positive and anchor from, - "any" means sampling negative from another track any time point - and using the augmented anchor patch as positive), by default "any" - return_negative : bool, optional - Whether to return the negative sample during the fit stage - (can be set to False when using a loss function like NT-Xent), - by default True - persistent_workers : bool, optional - Whether to keep worker processes alive between iterations, by default False - prefetch_factor : int | None, optional - Number of batches loaded in advance by each worker, by default None - pin_memory : bool, optional - Whether to pin memory in CPU for faster GPU transfer, by default False - z_window_size : int, optional - Size of the final Z window, by default None (inferred from z_range) - """ super().__init__( data_path=data_path, source_channel=source_channel, diff --git a/viscy/preprocessing/generate_masks.py b/viscy/preprocessing/generate_masks.py index 8614ee13e..9cc4e9ae9 100644 --- a/viscy/preprocessing/generate_masks.py +++ b/viscy/preprocessing/generate_masks.py @@ -1,4 +1,4 @@ -"""Generate masks from sum of flurophore channels""" +"""Generate masks from sum of fluorophore channels.""" from pathlib import Path from typing import Literal @@ -9,7 +9,26 @@ class MaskProcessor: - """Appends Masks to zarr directories""" + """Append masks to zarr directories. + + Parameters + ---------- + zarr_dir : Path + Directory of HCS zarr store to pull data from. Note: data in store is assumed to be stored in TCZYX format. + channel_ids : list[int] | int + Channel indices to be masked (typically just one) + time_ids : list[int] | int + Timepoints to consider + pos_ids : list[int] | int + Position (FOV) indices to use + num_workers : int, optional + Number of workers for multiprocessing, by default 4 + mask_type : Literal["otsu", "unimodal", "mem_detection", "borders_weight_loss_map"], optional + Method to use for generating mask. Needed for mapping to the masking function.
+ One of: {'otsu', 'unimodal', 'mem_detection', 'borders_weight_loss_map'}, by default "otsu". + overwrite_ok : bool, optional + Overwrite existing masks, by default False. + """ def __init__( self, @@ -23,26 +42,6 @@ def __init__( ] = "otsu", overwrite_ok: bool = False, ): - """Initialize mask processor for generating masks from fluorophore channels. - - Parameters - ---------- - zarr_dir : str - Directory of HCS zarr store to pull data from. Note: data in store is assumed to be stored in TCZYX format. - channel_ids : list[int] | int - Channel indices to be masked (typically just one) - time_ids : list[int] | int - Timepoints to consider - pos_ids : list[int] | int - Position (FOV) indices to use - num_workers : int - Number of workers for multiprocessing - mask_type : str - Method to use for generating mask. Needed for mapping to the masking function. - One of: {'otsu', 'unimodal', 'mem_detection', 'borders_weight_loss_map'}. Default is 'otsu'. - overwrite_ok : bool - Overwrite existing masks. Default is False. - """ self.zarr_dir = zarr_dir self.num_workers = num_workers diff --git a/viscy/preprocessing/pixel_ratio.py b/viscy/preprocessing/pixel_ratio.py index b7701849c..36636e0a6 100644 --- a/viscy/preprocessing/pixel_ratio.py +++ b/viscy/preprocessing/pixel_ratio.py @@ -1,3 +1,5 @@ +"""Pixel ratio utilities for class balancing in semantic segmentation.""" + import dask.array as da from iohub.ngff import open_ome_zarr from numpy.typing import NDArray @@ -6,7 +8,7 @@ def sematic_class_weights( dataset_path: str, target_channel: str, num_classes: int = 3 ) -> NDArray: - """Computes class balancing weights for semantic segmentation. + """Compute class balancing weights for semantic segmentation. The weights can be used for cross-entropy loss. diff --git a/viscy/preprocessing/precompute.py b/viscy/preprocessing/precompute.py index a23aa1e57..721473291 100644 --- a/viscy/preprocessing/precompute.py +++ b/viscy/preprocessing/precompute.py @@ -1,4 +1,4 @@ -"""Precompute normalization and store a plain C array""" +"""Precompute normalization and store a plain C array.""" from __future__ import annotations diff --git a/viscy/representation/classification.py b/viscy/representation/classification.py index ae64f2926..aae9a10ab 100644 --- a/viscy/representation/classification.py +++ b/viscy/representation/classification.py @@ -1,3 +1,5 @@ +"""Classification module for binary classification tasks.""" + from pathlib import Path from typing import Any @@ -17,16 +19,14 @@ class ClassificationPredictionWriter(BasePredictionWriter): Collects predictions from all batches and writes them to a CSV file at the end of each epoch. Converts tensor outputs to numpy arrays for storage. + + Parameters + ---------- + output_path : Path + Path to the output CSV file. """ def __init__(self, output_path: Path) -> None: - """Initialize the prediction writer. - - Parameters - ---------- - output_path : Path - Path to the output CSV file. - """ super().__init__("epoch") if Path(output_path).exists(): raise FileExistsError(f"Output path {output_path} already exists.") @@ -67,6 +67,15 @@ class ClassificationModule(LightningModule): Adapts a contrastive encoder for binary classification by replacing the final linear layer and adding classification-specific training logic. Computes binary cross-entropy loss and tracks accuracy and F1-score metrics. + + Parameters + ---------- + encoder : ContrastiveEncoder + Contrastive encoder model. + lr : float | None + Learning rate. + loss : nn.Module | None + Loss function. 
By default, BCEWithLogitsLoss with positive weight of 1.0. """ def __init__( @@ -75,17 +84,6 @@ def __init__( lr: float | None, loss: nn.Module | None = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(1.0)), ) -> None: - """Initialize the classification module. - - Parameters - ---------- - encoder : ContrastiveEncoder - Contrastive encoder model. - lr : float | None - Learning rate. - loss : nn.Module | None - Loss function. By default, BCEWithLogitsLoss with positive weight of 1.0. - """ super().__init__() self.stem = encoder.stem self.backbone = encoder.encoder diff --git a/viscy/representation/contrastive.py b/viscy/representation/contrastive.py index df6094ee9..165a6bc00 100644 --- a/viscy/representation/contrastive.py +++ b/viscy/representation/contrastive.py @@ -1,3 +1,5 @@ +"""Contrastive encoder network that uses ConvNeXt v1 and ResNet backbones from timm.""" + from typing import Literal import timm diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index 0e4db683c..f874d60f4 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -1,3 +1,5 @@ +"""Embedding writer module for writing embeddings to a zarr store in an Xarray-compatible format.""" + import logging from collections.abc import Sequence from pathlib import Path diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index a6d3f3db9..3ad7e7da1 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -1,3 +1,5 @@ +"""Contrastive learning model for self-supervised learning.""" + import logging from collections.abc import Sequence from typing import Literal, TypedDict diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index fae2e2248..86fc3b7a3 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -1,3 +1,5 @@ +"""Distance evaluation module for computing displacement and cosine similarities.""" + from collections import defaultdict from typing import Literal diff --git a/viscy/representation/evaluation/feature.py b/viscy/representation/evaluation/feature.py index c7ec70647..1bb3ba5e6 100644 --- a/viscy/representation/evaluation/feature.py +++ b/viscy/representation/evaluation/feature.py @@ -1,3 +1,5 @@ +"""Feature extraction module for computing various features from a single cell image patch.""" + from typing import TypedDict import mahotas as mh diff --git a/viscy/representation/evaluation/visualization.py b/viscy/representation/evaluation/visualization.py index df08d0e56..5ae1ce660 100644 --- a/viscy/representation/evaluation/visualization.py +++ b/viscy/representation/evaluation/visualization.py @@ -1,3 +1,5 @@ +"""Interactive visualization app for embedding analysis.""" + import atexit import base64 import json diff --git a/viscy/representation/multi_modal.py b/viscy/representation/multi_modal.py index 51e429a45..56d543804 100644 --- a/viscy/representation/multi_modal.py +++ b/viscy/representation/multi_modal.py @@ -1,3 +1,5 @@ +"""Joint multi-modal encoders for cross-modal representation learning.""" + from collections.abc import Sequence from logging import getLogger from typing import Literal diff --git a/viscy/trainer.py b/viscy/trainer.py index 5f12db396..caf017e95 100644 --- a/viscy/trainer.py +++ b/viscy/trainer.py @@ -1,3 +1,5 @@ +"""Extended Lightning Trainer for VisCy with preprocessing and export capabilities.""" + import logging from pathlib import Path from typing 
import Literal diff --git a/viscy/transforms/__init__.py b/viscy/transforms/__init__.py index 40712f9e1..6ca88eb4b 100644 --- a/viscy/transforms/__init__.py +++ b/viscy/transforms/__init__.py @@ -1,3 +1,5 @@ +"""VisCy transform package for data preprocessing and augmentation.""" + from viscy.transforms._redef import ( CenterSpatialCropd, Decollated, diff --git a/viscy/transforms/_redef.py b/viscy/transforms/_redef.py index 9171ffd03..a4031ea8f 100644 --- a/viscy/transforms/_redef.py +++ b/viscy/transforms/_redef.py @@ -21,6 +21,11 @@ class Decollated(Decollated): + """Decollate data wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, @@ -39,11 +44,21 @@ def __init__( class ToDeviced(ToDeviced): + """Transfer data to device wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__(self, keys: Sequence[str] | str, **kwargs: Any) -> None: super().__init__(keys=keys, **kwargs) class RandWeightedCropd(RandWeightedCropd): + """Random weighted crop wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, @@ -62,6 +77,11 @@ def __init__( class RandAffined(RandAffined): + """Random affine transform wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, @@ -82,6 +102,11 @@ def __init__( class RandAdjustContrastd(RandAdjustContrastd): + """Random contrast adjustment wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, @@ -93,6 +118,11 @@ def __init__( class RandScaleIntensityd(RandScaleIntensityd): + """Random intensity scaling wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, @@ -104,6 +134,11 @@ def __init__( class RandGaussianNoised(RandGaussianNoised): + """Random Gaussian noise wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, @@ -116,6 +151,11 @@ def __init__( class RandGaussianSmoothd(RandGaussianSmoothd): + """Random Gaussian smoothing wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, @@ -136,6 +176,11 @@ def __init__( class ScaleIntensityRangePercentilesd(ScaleIntensityRangePercentilesd): + """Scale intensity by percentile range wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, @@ -164,6 +209,11 @@ def __init__( class RandSpatialCropd(RandSpatialCropd): + """Random spatial crop wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, @@ -180,6 +230,11 @@ def __init__( class CenterSpatialCropd(CenterSpatialCropd): + """Center spatial crop wrapper for jsonargparse compatibility. + + See parent class documentation for details. + """ + def __init__( self, keys: Sequence[str] | str, @@ -190,6 +245,11 @@ def __init__( class RandFlipd(RandFlipd): + """Random flip wrapper for jsonargparse compatibility. + + See parent class documentation for details. 
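+ + Example sketch; parameters follow the MONAI parent class (key names and probability are arbitrary): + + >>> flip = RandFlipd(keys=["source", "target"], prob=0.5, spatial_axis=1)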
+ """ + def __init__( self, keys: Sequence[str] | str, diff --git a/viscy/translation/engine.py b/viscy/translation/engine.py index 40eb0dd4c..1cc4e23ee 100644 --- a/viscy/translation/engine.py +++ b/viscy/translation/engine.py @@ -1,3 +1,5 @@ +"""Training engine for virtual staining and image translation models.""" + import logging import os import random @@ -549,7 +551,7 @@ def on_test_start(self) -> None: ) def on_predict_start(self) -> None: - """Setup prediction padding transform. + """Set up prediction padding transform. Pad the input shape to be divisible by the downsampling factor. The inverse of this transform crops the prediction to original shape. @@ -684,7 +686,7 @@ def forward(self, x: Tensor) -> Tensor: return self.model(x) def setup(self, stage: str) -> None: - """Setup method for Lightning module. + """Set up the Lightning module for the specified stage. Parameters ---------- @@ -782,7 +784,7 @@ def __init__( self.save_hyperparameters(ignore=["loss_function"]) def on_fit_start(self) -> None: - """Setup data modules and validate configuration for training. + """Set up data modules and validate configuration for training. Raises ------ diff --git a/viscy/unet/__init__.py b/viscy/unet/__init__.py index e69de29bb..5e46d0396 100644 --- a/viscy/unet/__init__.py +++ b/viscy/unet/__init__.py @@ -0,0 +1 @@ +"""U-Net architectures for VisCy.""" diff --git a/viscy/unet/networks/Unet25D.py b/viscy/unet/networks/Unet25D.py index 41cda8628..9d9c6fb1a 100644 --- a/viscy/unet/networks/Unet25D.py +++ b/viscy/unet/networks/Unet25D.py @@ -1,3 +1,5 @@ +"""2.5D U-Net implementation for volumetric image processing.""" + from typing import Literal import torch @@ -11,9 +13,46 @@ class Unet25d(nn.Module): A hybrid approach that processes 3D input stacks but outputs 2D predictions. Combines 3D spatial information with 2D computational efficiency. + + Architecture takes in stack of 2D inputs given as a 3D tensor + and returns a 2D interpretation. Learns 3D information based upon input stack, + but speeds up training by compressing 3D information before the decoding path. + Uses interruption conv layers in the U-Net skip paths to + compress information with z-channel convolution. + + References + ---------- + https://elifesciences.org/articles/55502 + + Parameters + ---------- + in_channels : int, optional + Number of feature channels in (1 or more), by default 1. + out_channels : int, optional + Number of feature channels out (1 or more), by default 1. + in_stack_depth : int, optional + Depth of input stack in z, by default 5. + out_stack_depth : int, optional + Depth of output stack, by default 1. + xy_kernel_size : int or tuple of int, optional + Size of x and y dimensions of conv kernels in blocks, by default (3, 3). + residual : bool, optional + Whether to use residual connections, by default False. + dropout : float, optional + Probability of dropout, between 0 and 0.5, by default 0.2. + num_blocks : int, optional + Number of convolutional blocks on encoder and decoder paths, by default 4. + num_block_layers : int, optional + Number of layer sequences repeated per block, by default 2. + num_filters : list of int, optional + List of filters/feature levels at each conv block depth, by default []. + task : str, optional + Network task (for virtual staining this is regression), + one of 'seg','reg', by default "seg". 
""" def __name__(self) -> str: + """Return the name of the network architecture.""" return "Unet25d" def __init__( @@ -30,44 +69,6 @@ def __init__( num_filters: list[int] = [], task: Literal["seg", "reg"] = "seg", ) -> None: - """Initialize 2.5D U-Net. - - Architecture takes in stack of 2D inputs given as a 3D tensor - and returns a 2D interpretation. Learns 3D information based upon input stack, - but speeds up training by compressing 3D information before the decoding path. - Uses interruption conv layers in the U-Net skip paths to - compress information with z-channel convolution. - - References - ---------- - https://elifesciences.org/articles/55502 - - Parameters - ---------- - in_channels : int, optional - Number of feature channels in (1 or more), by default 1. - out_channels : int, optional - Number of feature channels out (1 or more), by default 1. - in_stack_depth : int, optional - Depth of input stack in z, by default 5. - out_stack_depth : int, optional - Depth of output stack, by default 1. - xy_kernel_size : int or tuple of int, optional - Size of x and y dimensions of conv kernels in blocks, by default (3, 3). - residual : bool, optional - Whether to use residual connections, by default False. - dropout : float, optional - Probability of dropout, between 0 and 0.5, by default 0.2. - num_blocks : int, optional - Number of convolutional blocks on encoder and decoder paths, by default 4. - num_block_layers : int, optional - Number of layer sequences repeated per block, by default 2. - num_filters : list of int, optional - List of filters/feature levels at each conv block depth, by default []. - task : str, optional - Network task (for virtual staining this is regression), - one of 'seg','reg', by default "seg". - """ super().__init__() self.in_channels = in_channels self.num_blocks = num_blocks @@ -265,7 +266,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x def register_modules(self, module_list: list[nn.Module], name: str) -> None: - """Helper function that registers modules stored in a list to the model object. + """Register modules stored in a list to the model object. So that they can be seen by PyTorch optimizer. diff --git a/viscy/unet/networks/Unet2D.py b/viscy/unet/networks/Unet2D.py index 225b43523..2d3b5ac22 100644 --- a/viscy/unet/networks/Unet2D.py +++ b/viscy/unet/networks/Unet2D.py @@ -1,3 +1,5 @@ +"""2D U-Net implementation for image-to-image translation tasks.""" + import torch import torch.nn as nn @@ -9,9 +11,39 @@ class Unet2d(nn.Module): A convolutional neural network following the U-Net architecture for 2D images. Supports both segmentation and regression tasks with configurable depth and filters. + + Follows 2D UNet Architecture: + + References + ---------- + 1) U-Net: https://arxiv.org/pdf/1505.04597.pdf + 2) Residual U-Net: https://arxiv.org/pdf/1711.10684.pdf + + Parameters + ---------- + in_channels : int, optional + Number of feature channels in, by default 1. + out_channels : int, optional + Number of feature channels out, by default 1. + kernel_size : int or tuple of int, optional + Size of x and y dimensions of conv kernels in blocks, by default (3, 3). + residual : bool, optional + Whether to use residual connections, by default False. + dropout : float, optional + Probability of dropout, between 0 and 0.5, by default 0.2. + num_blocks : int, optional + Number of convolutional blocks on encoder and decoder, by default 4. + num_block_layers : int, optional + Number of layers per block, by default 2. 
+ num_filters : list of int, optional + List of filters/feature levels at each conv block depth, by default []. + task : str, optional + Network task (for virtual staining this is regression), + one of 'seg','reg', by default "seg". """ def __name__(self): + """Return the name of the network architecture.""" return "Unet2d" def __init__( @@ -26,37 +58,6 @@ def __init__( num_filters=[], task="seg", ): - """Initialize 2D U-Net with variable input/output channels and depth. - - Follows 2D UNet Architecture: - - References - ---------- - 1) U-Net: https://arxiv.org/pdf/1505.04597.pdf - 2) Residual U-Net: https://arxiv.org/pdf/1711.10684.pdf - - Parameters - ---------- - in_channels : int, optional - Number of feature channels in, by default 1. - out_channels : int, optional - Number of feature channels out, by default 1. - kernel_size : int or tuple of int, optional - Size of x and y dimensions of conv kernels in blocks, by default (3, 3). - residual : bool, optional - Whether to use residual connections, by default False. - dropout : float, optional - Probability of dropout, between 0 and 0.5, by default 0.2. - num_blocks : int, optional - Number of convolutional blocks on encoder and decoder, by default 4. - num_block_layers : int, optional - Number of layers per block, by default 2. - num_filters : list of int, optional - List of filters/feature levels at each conv block depth, by default []. - task : str, optional - Network task (for virtual staining this is regression), - one of 'seg','reg', by default "seg". - """ super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -235,7 +236,7 @@ def forward(self, x: torch.Tensor, validate_input: bool = False) -> torch.Tensor return x.unsqueeze(2) def register_modules(self, module_list: list[nn.Module], name: str) -> None: - """Helper function that registers modules stored in a list to the model object. + """Register modules stored in a list to the model object. So that they can be seen by PyTorch optimizer. diff --git a/viscy/unet/networks/__init__.py b/viscy/unet/networks/__init__.py index e69de29bb..c21eefda3 100644 --- a/viscy/unet/networks/__init__.py +++ b/viscy/unet/networks/__init__.py @@ -0,0 +1 @@ +"""Neural network architectures for VisCy U-Net implementations.""" diff --git a/viscy/unet/networks/layers/ConvBlock3D.py b/viscy/unet/networks/layers/ConvBlock3D.py index 3840c43a0..6895d8baa 100644 --- a/viscy/unet/networks/layers/ConvBlock3D.py +++ b/viscy/unet/networks/layers/ConvBlock3D.py @@ -1,3 +1,5 @@ +"""3D convolutional blocks for volumetric neural network architectures.""" + from typing import Literal import numpy as np @@ -246,7 +248,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Forward call of convolutional block + Forward call of convolutional block. Order of layers within the block is defined by the 'layer_order' parameter, which is a string of 'c's, 'a's and 'n's in reference to @@ -320,7 +322,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def model(self) -> nn.Sequential: """ - Allows calling of parameters inside ConvBlock object. + Create sequential model from ConvBlock parameters. Layer order: convolution -> normalization -> activation @@ -351,7 +353,7 @@ def model(self) -> nn.Sequential: def register_modules(self, module_list: list[nn.Module], name: str) -> None: """ - Helper function that registers modules for PyTorch optimizer visibility. + Register modules for PyTorch optimizer visibility. 
Used to enable model graph creation with non-sequential model types and dynamic layer numbers diff --git a/viscy/unet/networks/layers/__init__.py b/viscy/unet/networks/layers/__init__.py index e69de29bb..54f8699e9 100644 --- a/viscy/unet/networks/layers/__init__.py +++ b/viscy/unet/networks/layers/__init__.py @@ -0,0 +1 @@ +"""Convolutional building blocks for neural network layers.""" diff --git a/viscy/unet/networks/unext2.py b/viscy/unet/networks/unext2.py index e100ddc8c..cf4b925da 100644 --- a/viscy/unet/networks/unext2.py +++ b/viscy/unet/networks/unext2.py @@ -1,3 +1,5 @@ +"""UNeXt2: ConvNeXt-based U-Net implementation for advanced neural network architectures.""" + from collections.abc import Callable, Sequence from typing import Literal @@ -404,7 +406,7 @@ def forward(self, x: Tensor) -> Tensor: class UnsqueezeHead(nn.Module): - """Unsqueeze 2D (B, C, H, W) feature map to 3D (B, C, 1, H, W) output""" + """Unsqueeze 2D (B, C, H, W) feature map to 3D (B, C, 1, H, W) output.""" def __init__(self) -> None: super().__init__() @@ -599,7 +601,7 @@ def __init__( @property def num_blocks(self) -> int: - """2-times downscaling factor of the smallest feature map""" + """2-times downscaling factor of the smallest feature map.""" return 6 def forward(self, x: Tensor) -> Tensor: diff --git a/viscy/utils/__init__.py b/viscy/utils/__init__.py index 5e2c7e1ed..b0ba3e69f 100644 --- a/viscy/utils/__init__.py +++ b/viscy/utils/__init__.py @@ -1 +1 @@ -"""Module for utility functions""" +"""Module for utility functions.""" diff --git a/viscy/utils/cli_utils.py b/viscy/utils/cli_utils.py index a88815172..96373fe70 100644 --- a/viscy/utils/cli_utils.py +++ b/viscy/utils/cli_utils.py @@ -53,7 +53,7 @@ class MultiProcessProgressBar: Parameters ---------- total_updates : int - Total number of updates. + Total number of updates expected for this progress bar. """ def __init__(self, total_updates: int) -> None: diff --git a/viscy/utils/logging.py b/viscy/utils/logging.py index e00821d63..25b9897fe 100644 --- a/viscy/utils/logging.py +++ b/viscy/utils/logging.py @@ -1,3 +1,5 @@ +"""Feature map logging utilities for neural network debugging.""" + import datetime import os import time @@ -56,8 +58,7 @@ def log_feature( class FeatureLogger: - """ - Logger for visualizing neural network feature maps during training and debugging. + """Logger for visualizing neural network feature maps during training and debugging. This utility class provides comprehensive feature map visualization capabilities for monitoring convolutional neural network activations. It supports both @@ -69,20 +70,38 @@ class FeatureLogger: It handles multi-dimensional tensors commonly found in computer vision tasks, including 2D/3D spatial dimensions with batch and channel axes. + Parameters + ---------- + save_folder : str + Output directory for saving visualization files. + spatial_dims : int, optional + Number of spatial dimensions in feature tensors, by default 3. + full_batch : bool, optional + If true, log all samples in batch (warning: slow!), by default False. + save_as_grid : bool, optional + If true, feature maps are saved as a grid containing all channels, + else saved individually, by default True. + grid_width : int, optional + Desired width of grid if save_as_grid. If 0, defaults to 1/4 the + number of channels, by default 0. + normalize_by_grid : bool, optional + If true, images saved in grid are normalized to brightest pixel in + entire grid, by default False. 
+ Attributes ---------- save_folder : str - Directory path for saving visualization outputs + Directory path for saving visualization outputs. spatial_dims : int - Number of spatial dimensions in feature tensors (2D or 3D) + Number of spatial dimensions in feature tensors (2D or 3D). full_batch : bool - Whether to log all samples in batch or just the first + Whether to log all samples in batch or just the first. save_as_grid : bool - Whether to arrange channels in a grid layout + Whether to arrange channels in a grid layout. grid_width : int - Number of columns in grid visualization + Number of columns in grid visualization. normalize_by_grid : bool - Whether to normalize intensities across entire grid + Whether to normalize intensities across entire grid. Examples -------- @@ -106,29 +125,6 @@ def __init__( grid_width: int = 0, normalize_by_grid: bool = False, ) -> None: - """Initialize logger for handling feature map visualization in neural networks. - - Saves each 2D slice of a feature map in either a single grid per feature map - stack or a directory tree of labeled slices. By default saves images into grid. - - Parameters - ---------- - save_folder : str - Output directory for saving visualization files. - spatial_dims : int, optional - Number of spatial dimensions in feature tensors, by default 3. - full_batch : bool, optional - If true, log all samples in batch (warning: slow!), by default False. - save_as_grid : bool, optional - If true, feature maps are saved as a grid containing all channels, - else saved individually, by default True. - grid_width : int, optional - Desired width of grid if save_as_grid. If 0, defaults to 1/4 the - number of channels, by default 0. - normalize_by_grid : bool, optional - If true, images saved in grid are normalized to brightest pixel in - entire grid, by default False. - """ self.save_folder = save_folder self.spatial_dims = spatial_dims self.full_batch = full_batch diff --git a/viscy/utils/masks.py b/viscy/utils/masks.py index f000fce94..aba09fc83 100644 --- a/viscy/utils/masks.py +++ b/viscy/utils/masks.py @@ -1,3 +1,5 @@ +"""Mask generation and processing utilities.""" + from typing import Any import numpy as np diff --git a/viscy/utils/meta_utils.py b/viscy/utils/meta_utils.py index 9cb753f78..d5a729749 100644 --- a/viscy/utils/meta_utils.py +++ b/viscy/utils/meta_utils.py @@ -1,3 +1,5 @@ +"""Metadata utilities for dataset analysis and normalization statistics.""" + import os import sys from pathlib import Path diff --git a/viscy/utils/mp_utils.py b/viscy/utils/mp_utils.py index 686dade41..7a7239379 100644 --- a/viscy/utils/mp_utils.py +++ b/viscy/utils/mp_utils.py @@ -1,3 +1,5 @@ +"""Multiprocessing utilities for parallel data processing.""" + from collections.abc import Callable from concurrent.futures import ProcessPoolExecutor from typing import Any @@ -278,8 +280,7 @@ def get_mask_slice( def mp_get_val_stats(fn_args: list[Any], workers: int) -> list[dict[str, float]]: - """ - Compute statistics of numpy arrays with multiprocessing + """Compute statistics of numpy arrays with multiprocessing. Parameters ---------- @@ -327,7 +328,7 @@ def get_val_stats(sample_values: list[float]) -> dict[str, float]: def mp_sample_im_pixels( fn_args: list[tuple[Any, ...]], workers: int ) -> list[list[Any]]: - """Read and computes statistics of images with multiprocessing + """Read and compute statistics of images with multiprocessing. 
Parameters ---------- diff --git a/viscy/utils/slurm_utils.py b/viscy/utils/slurm_utils.py index c943b8b11..f2fc2bdf8 100644 --- a/viscy/utils/slurm_utils.py +++ b/viscy/utils/slurm_utils.py @@ -1,3 +1,5 @@ +"""SLURM cluster utilities for resource management.""" + import psutil import torch From ae89d6a63cd5143302f7262fcacc4ead9d8cca3a Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Thu, 11 Sep 2025 11:25:31 -0700 Subject: [PATCH 04/13] moved ruff.toml contents into pyproject.toml --- pyproject.toml | 33 ++ ruff.toml | 38 -- viscy/__init__.py | 1 - viscy/cli.py | 2 - viscy/data/__init__.py | 1 - viscy/data/cell_classification.py | 2 - viscy/data/combined.py | 7 - viscy/data/ctmc_v1.py | 2 - viscy/data/distributed.py | 2 - viscy/data/gpu_aug.py | 2 - viscy/data/hcs.py | 2 - viscy/data/livecell.py | 2 - viscy/data/mmap_cache.py | 2 - viscy/data/select.py | 2 - viscy/data/triplet.py | 2 - viscy/data/typing.py | 2 - viscy/preprocessing/pixel_ratio.py | 2 - viscy/preprocessing/precompute.py | 2 - viscy/representation/classification.py | 2 - viscy/representation/contrastive.py | 2 - viscy/representation/embedding_writer.py | 2 - viscy/representation/engine.py | 2 - viscy/representation/evaluation/clustering.py | 2 - .../evaluation/dimensionality_reduction.py | 2 - viscy/representation/evaluation/distance.py | 2 - viscy/representation/evaluation/feature.py | 2 - .../evaluation/visualization.py | 2 - viscy/representation/multi_modal.py | 2 - viscy/trainer.py | 2 - viscy/translation/evaluation_metrics.py | 419 ++++++------------ viscy/translation/predict_writer.py | 2 - viscy/unet/__init__.py | 1 - viscy/unet/networks/Unet25D.py | 2 - viscy/unet/networks/__init__.py | 1 - viscy/unet/networks/layers/ConvBlock2D.py | 2 - viscy/unet/networks/layers/ConvBlock3D.py | 2 - viscy/unet/networks/layers/__init__.py | 1 - viscy/unet/networks/unext2.py | 2 - viscy/utils/__init__.py | 1 - viscy/utils/cli_utils.py | 2 - viscy/utils/log_images.py | 2 - viscy/utils/logging.py | 2 - viscy/utils/masks.py | 2 - viscy/utils/meta_utils.py | 8 +- viscy/utils/mp_utils.py | 2 - viscy/utils/slurm_utils.py | 2 - 46 files changed, 178 insertions(+), 403 deletions(-) delete mode 100644 ruff.toml diff --git a/pyproject.toml b/pyproject.toml index 04b11a58a..67a958a14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,3 +74,36 @@ packages = ["viscy"] [tool.setuptools_scm] write_to = "viscy/_version.py" +[tool.ruff] +line-length = 88 +src = ["viscy", "tests"] +extend-include = ["*.ipynb"] +target-version = "py311" +# Exclude the following for now. Later on we should check every Python file. 
+extend-exclude = ["viscy/scripts/*", "applications/*", "examples/*"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +docstring-code-format = true +docstring-code-line-length = "dynamic" + +[tool.ruff.lint] +select = [ + "D", # pydocstyle + "I", # isort +] +ignore = [ + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D105", # __magic__ methods are often self-explanatory, allow missing docstrings + "D107", # Missing docstring in __init__ + # Disable one in each pair of mutually incompatible rules + "D203", # We don’t want a blank line before a class docstring + "D213", # <> We want docstrings to start immediately after the opening triple quote + "D400", # first line should end with a period [Bug: doesn’t work with single-line docstrings] + "D401", # First line should be in imperative mood; try rephrasing +] +per-file-ignores."*/__init__.py" = ["F401"] +per-file-ignores."tests/*" = ["D"] +pydocstyle.convention = "numpy" \ No newline at end of file diff --git a/ruff.toml b/ruff.toml deleted file mode 100644 index 7e3f1e007..000000000 --- a/ruff.toml +++ /dev/null @@ -1,38 +0,0 @@ -# This file is used to configure the Ruff linter and formatter: -# View the documentation for more information on how to configure this file below -# https://docs.astral.sh/ruff/linter/ -# https://docs.astral.sh/ruff/formatter/ - - -line-length = 88 -src = ["viscy", "tests"] -extend-include = ["*.ipynb"] -target-version = "py310" -# Exclude the following for now. Later on we should check every Python file, no exceptions. -extend-exclude = ["viscy/scripts/*", "applications/*", "examples/*"] - -[format] -quote-style = "double" -indent-style = "space" -docstring-code-format = true -docstring-code-line-length = "dynamic" - -[lint] -select = [ - "D", # pydocstyle - "I", # isort -] -ignore = [ - # "D100", # Missing docstring in public module - # "D104", # Missing docstring in public package - "D105", # __magic__ methods are often self-explanatory, allow missing docstrings - "D107", # Missing docstring in __init__ - # Disable one in each pair of mutually incompatible rules - "D203", # We don’t want a blank line before a class docstring - "D213", # <> We want docstrings to start immediately after the opening triple quote - "D400", # first line should end with a period [Bug: doesn’t work with single-line docstrings] - "D401", # First line should be in imperative mood; try rephrasing -] -per-file-ignores."*/__init__.py" = ["F401"] -per-file-ignores."tests/*" = ["D"] -pydocstyle.convention = "numpy" \ No newline at end of file diff --git a/viscy/__init__.py b/viscy/__init__.py index 5f1ef031e..e69de29bb 100644 --- a/viscy/__init__.py +++ b/viscy/__init__.py @@ -1 +0,0 @@ -"""Learning vision for cells.""" diff --git a/viscy/cli.py b/viscy/cli.py index ad6537e98..f85d30786 100644 --- a/viscy/cli.py +++ b/viscy/cli.py @@ -1,5 +1,3 @@ -"""Lightning CLI for computer vision models in VisCy.""" - import logging import os import sys diff --git a/viscy/data/__init__.py b/viscy/data/__init__.py index f8b10be0b..e69de29bb 100644 --- a/viscy/data/__init__.py +++ b/viscy/data/__init__.py @@ -1 +0,0 @@ -"""VisCy data loading and preprocessing modules.""" diff --git a/viscy/data/cell_classification.py b/viscy/data/cell_classification.py index 292e0507f..7b42cd82b 100644 --- a/viscy/data/cell_classification.py +++ b/viscy/data/cell_classification.py @@ -1,5 +1,3 @@ -"""Dataset and DataModule classes for cell classification tasks.""" - from collections.abc import Callable 
from pathlib import Path diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 90e111729..9fbdfab9c 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -1,10 +1,3 @@ -"""Combined data modules for multi-dataset ML training workflows. - -This module provides Lightning DataModule implementations for combining multiple -data sources with various strategies including concatenation, batching, and -distributed sampling optimizations for computer vision and microscopy datasets. -""" - import bisect import logging from collections import defaultdict diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index 811a1c6ba..468b0d676 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -1,5 +1,3 @@ -"""Data module for CTMCv1 autoregression dataset with HCS OME-Zarr stores.""" - from pathlib import Path import torch diff --git a/viscy/data/distributed.py b/viscy/data/distributed.py index 8da588fea..beab41dc2 100644 --- a/viscy/data/distributed.py +++ b/viscy/data/distributed.py @@ -1,5 +1,3 @@ -"""Utilities for DDP training.""" - from __future__ import annotations import math diff --git a/viscy/data/gpu_aug.py b/viscy/data/gpu_aug.py index f5e8ecb4e..abca552e5 100644 --- a/viscy/data/gpu_aug.py +++ b/viscy/data/gpu_aug.py @@ -1,5 +1,3 @@ -"""GPU-accelerated data augmentation modules for microscopy ML training.""" - from __future__ import annotations from abc import ABC, abstractmethod diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 4cec0033b..785e24394 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -1,5 +1,3 @@ -"""High-Content Screening (HCS) data loading and preprocessing module.""" - import logging import math import os diff --git a/viscy/data/livecell.py b/viscy/data/livecell.py index 50e1c201d..d7134fe87 100644 --- a/viscy/data/livecell.py +++ b/viscy/data/livecell.py @@ -1,5 +1,3 @@ -"""LiveCell dataset implementation for cell segmentation benchmarking.""" - from __future__ import annotations import json diff --git a/viscy/data/mmap_cache.py b/viscy/data/mmap_cache.py index 10e967c10..b3cf427f4 100644 --- a/viscy/data/mmap_cache.py +++ b/viscy/data/mmap_cache.py @@ -1,5 +1,3 @@ -"""Memory-mapped caching for OME-Zarr data with efficient disk I/O.""" - from __future__ import annotations import os diff --git a/viscy/data/select.py b/viscy/data/select.py index 509ac2440..4a0e4539b 100644 --- a/viscy/data/select.py +++ b/viscy/data/select.py @@ -1,5 +1,3 @@ -"""Well and field-of-view selection utilities for HCS datasets.""" - from collections.abc import Generator from iohub.ngff.nodes import Plate, Position, Well diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index 3fbdb3be6..b899ac15e 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -1,5 +1,3 @@ -"""Triplet sampling for contrastive learning on tracked cell data.""" - import logging from collections.abc import Sequence from pathlib import Path diff --git a/viscy/data/typing.py b/viscy/data/typing.py index 3d800953a..eafb04d09 100644 --- a/viscy/data/typing.py +++ b/viscy/data/typing.py @@ -1,5 +1,3 @@ -"""Type definitions for VisCy data modules and structures.""" - from collections.abc import Callable, Sequence from typing import NamedTuple, TypedDict, TypeVar diff --git a/viscy/preprocessing/pixel_ratio.py b/viscy/preprocessing/pixel_ratio.py index 36636e0a6..285a15c09 100644 --- a/viscy/preprocessing/pixel_ratio.py +++ b/viscy/preprocessing/pixel_ratio.py @@ -1,5 +1,3 @@ -"""Pixel ratio utilities for class balancing in semantic segmentation.""" - import 
dask.array as da from iohub.ngff import open_ome_zarr from numpy.typing import NDArray diff --git a/viscy/preprocessing/precompute.py b/viscy/preprocessing/precompute.py index 721473291..cc94b69f6 100644 --- a/viscy/preprocessing/precompute.py +++ b/viscy/preprocessing/precompute.py @@ -1,5 +1,3 @@ -"""Precompute normalization and store a plain C array.""" - from __future__ import annotations from pathlib import Path diff --git a/viscy/representation/classification.py b/viscy/representation/classification.py index aae9a10ab..cb686f3b1 100644 --- a/viscy/representation/classification.py +++ b/viscy/representation/classification.py @@ -1,5 +1,3 @@ -"""Classification module for binary classification tasks.""" - from pathlib import Path from typing import Any diff --git a/viscy/representation/contrastive.py b/viscy/representation/contrastive.py index 165a6bc00..df6094ee9 100644 --- a/viscy/representation/contrastive.py +++ b/viscy/representation/contrastive.py @@ -1,5 +1,3 @@ -"""Contrastive encoder network that uses ConvNeXt v1 and ResNet backbones from timm.""" - from typing import Literal import timm diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index f874d60f4..0e4db683c 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -1,5 +1,3 @@ -"""Embedding writer module for writing embeddings to a zarr store in an Xarray-compatible format.""" - import logging from collections.abc import Sequence from pathlib import Path diff --git a/viscy/representation/engine.py b/viscy/representation/engine.py index 3ad7e7da1..a6d3f3db9 100644 --- a/viscy/representation/engine.py +++ b/viscy/representation/engine.py @@ -1,5 +1,3 @@ -"""Contrastive learning model for self-supervised learning.""" - import logging from collections.abc import Sequence from typing import Literal, TypedDict diff --git a/viscy/representation/evaluation/clustering.py b/viscy/representation/evaluation/clustering.py index dbdc6455c..afe09a8cd 100644 --- a/viscy/representation/evaluation/clustering.py +++ b/viscy/representation/evaluation/clustering.py @@ -1,5 +1,3 @@ -"""Methods for evaluating clustering performance.""" - import numpy as np from numpy.typing import ArrayLike, NDArray from scipy.spatial.distance import cdist diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index 5b0db1cb7..ed1c9c47c 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -1,5 +1,3 @@ -"""PCA and UMAP dimensionality reduction.""" - import pandas as pd import umap import xarray as xr diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index 86fc3b7a3..fae2e2248 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -1,5 +1,3 @@ -"""Distance evaluation module for computing displacement and cosine similarities.""" - from collections import defaultdict from typing import Literal diff --git a/viscy/representation/evaluation/feature.py b/viscy/representation/evaluation/feature.py index 1bb3ba5e6..c7ec70647 100644 --- a/viscy/representation/evaluation/feature.py +++ b/viscy/representation/evaluation/feature.py @@ -1,5 +1,3 @@ -"""Feature extraction module for computing various features from a single cell image patch.""" - from typing import TypedDict import mahotas as mh diff --git 
a/viscy/representation/evaluation/visualization.py b/viscy/representation/evaluation/visualization.py index 5ae1ce660..df08d0e56 100644 --- a/viscy/representation/evaluation/visualization.py +++ b/viscy/representation/evaluation/visualization.py @@ -1,5 +1,3 @@ -"""Interactive visualization app for embedding analysis.""" - import atexit import base64 import json diff --git a/viscy/representation/multi_modal.py b/viscy/representation/multi_modal.py index 56d543804..51e429a45 100644 --- a/viscy/representation/multi_modal.py +++ b/viscy/representation/multi_modal.py @@ -1,5 +1,3 @@ -"""Joint multi-modal encoders for cross-modal representation learning.""" - from collections.abc import Sequence from logging import getLogger from typing import Literal diff --git a/viscy/trainer.py b/viscy/trainer.py index caf017e95..5f12db396 100644 --- a/viscy/trainer.py +++ b/viscy/trainer.py @@ -1,5 +1,3 @@ -"""Extended Lightning Trainer for VisCy with preprocessing and export capabilities.""" - import logging from pathlib import Path from typing import Literal diff --git a/viscy/translation/evaluation_metrics.py b/viscy/translation/evaluation_metrics.py index 7403e5f08..2011b4def 100644 --- a/viscy/translation/evaluation_metrics.py +++ b/viscy/translation/evaluation_metrics.py @@ -1,6 +1,7 @@ -"""Metrics for model evaluation.""" +"""Metrics for model evaluation""" -from collections.abc import Sequence +from typing import Sequence, Union +from warnings import warn import numpy as np import torch @@ -10,25 +11,24 @@ from scipy.optimize import linear_sum_assignment from skimage.measure import label, regionprops from torchmetrics.detection.mean_ap import MeanAveragePrecision +from torchvision.ops import masks_to_boxes -def VOI_metric(target: NDArray, prediction: NDArray) -> list[float]: - """ - Variation of information metric. +def VOI_metric(target: np.array, prediction: np.array) -> float: + """Variation of information metric - Reports overlap between predicted and ground truth mask. + Reports overlap between predicted and ground truth mask Parameters ---------- - target : NDArray - Ground truth mask. - prediction : NDArray - Model inferred FL image cellpose mask. + target : np.array + Ground truth mask + prediction : np.array + Model inferred FL image cellpose mask Returns ------- - list[float] - VI for image masks. + float VI: VI for image masks """ # cellpose segmentation of predicted image: outputs labl mask pred_bin = prediction > 0 @@ -65,9 +65,7 @@ def VOI_metric(target: NDArray, prediction: NDArray) -> list[float]: return [VI] -def POD_metric( - target_bin: NDArray, pred_bin: NDArray -) -> tuple[float, float, float, int, int]: +def POD_metric(target_bin: NDArray, pred_bin: NDArray): """ Probability of detection metric for object matching. 
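The rewritten body in the next hunk drops the old (POD, FAR, PCD, n_target, n_pred) return tuple in favor of confusion-style detection statistics. How those outputs relate to the Hungarian-matched object counts, as a standalone sketch (hypothetical names, not code from this patch):

    def detection_stats(n_matched: int, n_pred: int, n_target: int):
        # Each matched (target, prediction) pair counts as one true positive;
        # unmatched predictions are false positives, unmatched targets are
        # false negatives. Assumes non-empty inputs, as the patched code does.
        tp = n_matched
        fp = n_pred - n_matched
        fn = n_target - n_matched
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
        return [tp, fp, fn, precision, recall, f1]

    print(detection_stats(8, 10, 9))  # e.g. precision 0.8, recall ~0.89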
@@ -132,260 +130,160 @@ def POD_metric( matching_targ.append(rid) matching_pred.append(cid) - # probability of detection - POD = len(matching_targ) / len(props_targ) - - # probability of false alarm - FAR = (len(props_pred) - len(matching_pred)) / len(props_pred) - - # probability of correct detection - PCD = len(matching_targ) / len(props_targ) + true_positives = len(matching_pred) + false_positives = n_predObj - len(matching_pred) + false_negatives = n_targObj - len(matching_targ) + precision = true_positives / (true_positives + false_positives) + recall = true_positives / (true_positives + false_negatives) + f1_score = 2 * (precision * recall / (precision + recall)) - return (POD, FAR, PCD, len(props_targ), len(props_pred)) + return [ + true_positives, + false_positives, + false_negatives, + precision, + recall, + f1_score, + ] -def compute_3d_dice_score( - y_true: torch.Tensor, - y_pred: torch.Tensor, - eps: float = 1e-8, - threshold: float = 0.5, - aggregate: bool = True, -) -> torch.Tensor: - """Compute 3D Dice similarity coefficient. +def labels_to_masks(labels: torch.ShortTensor) -> torch.BoolTensor: + """Convert integer labels to a stack of boolean masks. Parameters ---------- - y_true : torch.Tensor - True labels. - y_pred : torch.Tensor - Predicted labels. - eps : float, optional - Epsilon to avoid division by zero. Defaults to 1e-8. - threshold : float, optional - Threshold for binarization. Defaults to 0.5. - aggregate : bool, optional - Whether to aggregate the dice score. Defaults to True. + labels : torch.ShortTensor + 2D labels where each value is an object (0 is background) Returns ------- - torch.Tensor - Dice score. - """ - y_pred_thresholded = (y_pred > threshold).float() - intersection = torch.sum(y_true * y_pred_thresholded, dim=(-3, -2, -1)) - total = torch.sum(y_true + y_pred_thresholded, dim=(-3, -2, -1)) - dice = (2.0 * intersection + eps) / (total + eps) - if aggregate: - return torch.mean(dice) - return dice - - -def compute_jaccard_index( - y_true: torch.Tensor, - y_pred: torch.Tensor, - threshold: float = 0.5, -) -> torch.Tensor: - """Compute Jaccard index (IoU). + torch.BoolTensor + Boolean masks of shape (objects, H, W) - Parameters - ---------- - y_true : torch.Tensor - True labels. - y_pred : torch.Tensor - Predicted labels. - threshold : float, optional - Threshold for binarization. Defaults to 0.5. - - Returns - ------- - torch.Tensor - Jaccard index. """ - y_pred_thresholded = y_pred > threshold - intersection = torch.sum(y_true & y_pred_thresholded, dim=(-3, -2, -1)) - union = torch.sum(y_true | y_pred_thresholded, dim=(-3, -2, -1)) - return torch.mean(intersection.float() / union.float()) + if labels.ndim != 2: + raise ValueError(f"Labels must be 2D, got shape {labels.shape}.") + segments = torch.unique(labels) + n_instances = segments.numel() - 1 + masks = torch.zeros( + (n_instances, *labels.shape), dtype=torch.bool, device=labels.device + ) + # TODO: optimize this? + for s, segment in enumerate(segments): + # start from label value 1, i.e. skip background label + masks[s - 1] = labels == segment + return masks -def compute_pearson_correlation_coefficient( - y_true: torch.Tensor, y_pred: torch.Tensor, dim: Sequence[int] | None = None -) -> torch.Tensor: - """Compute Pearson correlation coefficient. +def labels_to_detection(labels: torch.ShortTensor) -> dict[str, torch.Tensor]: + """Convert integer labels to a torchvision/torchmetrics detection dictionary. Parameters ---------- - y_true : torch.Tensor - True labels. 
- y_pred : torch.Tensor - Predicted labels. - dim : Sequence[int] | None, optional - Dimensions to compute the Pearson correlation coefficient. Defaults to None. + labels : torch.ShortTensor + 2D labels where each value is an object (0 is background) Returns ------- - torch.Tensor - Pearson correlation coefficient. + dict[str, torch.Tensor] + detection boxes, scores, labels, and masks """ - if dim is None: - # default to spatial dimensions - dim = (-3, -2, -1) - y_true_centered = y_true - torch.mean(y_true, dim=dim, keepdim=True) - y_pred_centered = y_pred - torch.mean(y_pred, dim=dim, keepdim=True) - numerator = torch.sum(y_true_centered * y_pred_centered, dim=dim) - # compute stds - y_true_std = torch.sqrt(torch.sum(y_true_centered**2, dim=dim)) - y_pred_std = torch.sqrt(torch.sum(y_pred_centered**2, dim=dim)) - denominator = y_true_std * y_pred_std - # torch.full_like makes the entire tensor have the same value, - # so we have to use torch.full instead - small_correlation = torch.abs(denominator) < 1e-8 - pcc = torch.where( - small_correlation, torch.zeros_like(numerator), numerator / denominator - ) - return torch.mean(pcc) - - -class MeanAveragePrecisionNuclei(MeanAveragePrecision): - """Mean Average Precision for nuclei detection. + masks = labels_to_masks(labels) + boxes = masks_to_boxes(masks) + return { + "boxes": boxes, + # dummy confidence scores + "scores": torch.ones( + (boxes.shape[0],), dtype=torch.float32, device=boxes.device + ), + # dummy class labels + "labels": torch.zeros( + (boxes.shape[0],), dtype=torch.uint8, device=boxes.device + ), + "masks": masks, + } + + +def mean_average_precision( + pred_labels: torch.ShortTensor, target_labels: torch.ShortTensor, **kwargs +) -> dict[str, torch.Tensor]: + """Compute the mAP metric for instance segmentation. Parameters ---------- - min_area : int, optional - Minimum area of nuclei to be considered. Defaults to 20. - iou_threshold : float, optional - IoU threshold for matching. Defaults to 0.5. + pred_labels : torch.ShortTensor + 2D integer prediction labels + target_labels : torch.ShortTensor + 2D integer prediction labels + **kwargs : dict + Keyword arguments passed to + :py:class:`torchmetrics.detection.MeanAveragePrecision` Returns ------- - torch.Tensor - Mean average precision score. + dict[str, torch.Tensor] + COCO-style metrics """ + defaults = dict( + iou_type="segm", box_format="xyxy", max_detection_thresholds=[1, 100, 10000] + ) + if not kwargs: + kwargs = {} + map_metric = MeanAveragePrecision(**(defaults | kwargs)) + map_metric.update( + [labels_to_detection(pred_labels)], [labels_to_detection(target_labels)] + ) + return map_metric.compute() + - def __init__(self, min_area: int = 20, iou_threshold: float = 0.5) -> None: - super().__init__(iou_thresholds=[iou_threshold]) - self.min_area = min_area - - def __call__(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Compute mean average precision for nuclei detection. - - Parameters - ---------- - prediction : torch.Tensor - Predicted nuclei segmentation masks. - target : torch.Tensor - Ground truth nuclei segmentation masks. - - Returns - ------- - torch.Tensor - Mean average precision score. 
- """ - prediction_labels = label(prediction > 0.5) - target_labels = label(target > 0.5) - device = prediction.device - preds = [] - targets = [] - for i, (pred_img, target_img) in enumerate( - zip(prediction_labels, target_labels) - ): - pred_props = regionprops(pred_img) - # binary mask for each instance - pred_masks = torch.zeros( - len(pred_props), *pred_img.shape, dtype=torch.bool, device=device - ) - pred_labels = torch.zeros(len(pred_props), dtype=torch.long, device=device) - pred_scores = torch.ones(len(pred_props), dtype=torch.float, device=device) - for j, prop in enumerate(pred_props): - if prop.area < self.min_area: - continue - pred_masks[j, pred_img == prop.label] = True - pred_labels[j] = 1 # class 1 for nuclei - - target_props = regionprops(target_img) - target_masks = torch.zeros( - len(target_props), *target_img.shape, dtype=torch.bool, device=device - ) - target_labels = torch.zeros( - len(target_props), dtype=torch.long, device=device - ) - for j, prop in enumerate(target_props): - if prop.area < self.min_area: - continue - target_masks[j, target_img == prop.label] = True - target_labels[j] = 1 - - preds.append( - { - "masks": pred_masks, - "labels": pred_labels, - "scores": pred_scores, - } - ) - targets.append({"masks": target_masks, "labels": target_labels}) - return super().__call__(preds, targets) - - -def ssim_loss_25d( +def ssim_25d( preds: torch.Tensor, target: torch.Tensor, in_plane_window_size: tuple[int, int] = (11, 11), return_contrast_sensitivity: bool = False, -) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - """ - Multi-scale SSIM loss function for 2.5D volumes (3D with small depth). +) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Multi-scale SSIM loss function for 2.5D volumes (3D with small depth). Uses uniform kernel (windows), depth-dimension window size equals to depth size. Parameters ---------- preds : torch.Tensor - Predicted batch (B, C, D, W, H). + predicted batch (B, C, D, W, H) target : torch.Tensor - Target batch. + target batch in_plane_window_size : tuple[int, int], optional - Kernel width and height, by default (11, 11). + kernel width and height, by default (11, 11) return_contrast_sensitivity : bool, optional - Whether to return contrast sensitivity, by default False. + whether to return contrast sensitivity Returns ------- - torch.Tensor | tuple[torch.Tensor, torch.Tensor] - SSIM for the batch, optionally with contrast sensitivity. 
+ torch.Tensor: SSIM for the batch + Optional[torch.Tensor]: contrast sensitivity """ if preds.ndim != 5: raise ValueError( - f"Expected preds to have 5 dimensions (B, C, D, W, H), got {preds.ndim}" - ) - if preds.shape != target.shape: - raise ValueError( - f"Expected preds and target to have the same shape, " - f"got {preds.shape} and {target.shape}" + f"Input shape must be (B, C, D, W, H), got input shape {preds.shape}" ) - - B, C, D, H, W = preds.shape - # Compute SSIM for each channel and each depth slice - ssim_per_channel = [] - cs_per_channel = [] - - for c in range(C): - # Window size for depth dimension is the depth size - window_size = (*in_plane_window_size, D) - ssim, cs = compute_ssim_and_cs( - preds[:, c, :, :, :], target[:, c, :, :, :], window_size - ) - ssim_per_channel.append(ssim) - if return_contrast_sensitivity: - cs_per_channel.append(cs) - - # Average across channels - ssim_result = torch.mean(torch.stack(ssim_per_channel)) - + depth = preds.shape[2] + if depth > 15: + warn(f"Input depth {depth} is potentially too large for 2.5D SSIM.") + ssim_img, cs_img = compute_ssim_and_cs( + preds, + target, + 3, + kernel_sigma=None, + kernel_size=(depth, *in_plane_window_size), + data_range=target.max(), + kernel_type="uniform", + ) + # aggregate to one scalar per batch + ssim = ssim_img.view(ssim_img.shape[0], -1).mean(1) if return_contrast_sensitivity: - cs_result = torch.mean(torch.stack(cs_per_channel)) - return ssim_result, cs_result - - return ssim_result + return ssim, cs_img.view(cs_img.shape[0], -1).mean(1) + else: + return ssim def ms_ssim_25d( @@ -395,8 +293,7 @@ def ms_ssim_25d( clamp: bool = False, betas: Sequence[float] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333), ) -> torch.Tensor: - """ - Multi-scale SSIM for 2.5D volumes (3D with small depth). + """Multi-scale SSIM for 2.5D volumes (3D with small depth). Uses uniform kernel (windows), depth-dimension window size equals to depth size. Depth dimension is not downsampled. @@ -408,69 +305,37 @@ def ms_ssim_25d( Parameters ---------- preds : torch.Tensor - Predicted images. + predicted images target : torch.Tensor - Target images. + target images in_plane_window_size : tuple[int, int], optional - Kernel width and height, defaults to (11, 11). + kernel width and height, by default (11, 11) clamp : bool, optional - Clamp to [1e-6, 1] for training stability when used in loss, - defaults to False. + clamp to [1e-6, 1] for training stability when used in loss, + by default False betas : Sequence[float], optional - Exponents of each resolution, - defaults to (0.0448, 0.2856, 0.3001, 0.2363, 0.1333). + exponents of each resolution, by default (0.0448, 0.2856, 0.3001, 0.2363, 0.1333) Returns ------- - torch.Tensor - Multi-scale SSIM. 
+ torch.Tensor: multi-scale SSIM """ base_min = 1e-4 mcs_list = [] - ssim_list = [] - - B, C, D, H, W = preds.shape - - for c in range(C): - # Window size for depth dimension is the depth size - window_size = (*in_plane_window_size, D) - - pred_c = preds[:, c] - target_c = target[:, c] - - for level in range(len(betas)): - if level > 0: - # Downsample only in spatial dimensions, not depth - pred_c = F.avg_pool2d(pred_c.view(-1, H, W), kernel_size=2).view( - B, D, H // 2, W // 2 - ) - target_c = F.avg_pool2d(target_c.view(-1, H, W), kernel_size=2).view( - B, D, H // 2, W // 2 - ) - H, W = H // 2, W // 2 - - ssim, cs = compute_ssim_and_cs(pred_c, target_c, window_size) - - if level == len(betas) - 1: - ssim_list.append(ssim) - else: - mcs_list.append(cs) - - # Compute the final ms-ssim score - mcs_tensor = torch.stack(mcs_list) - ssim_tensor = torch.stack(ssim_list) - - # Apply betas weighting - betas_tensor = torch.tensor(betas, device=preds.device, dtype=preds.dtype) - - # For numerical stability + for _ in range(len(betas)): + ssim, contrast_sensitivity = ssim_25d( + preds, target, in_plane_window_size, return_contrast_sensitivity=True + ) + if clamp: + contrast_sensitivity = contrast_sensitivity.clamp(min=base_min) + mcs_list.append(contrast_sensitivity) + # do not downsample along depth + preds = F.avg_pool3d(preds, (1, 2, 2)) + target = F.avg_pool3d(target, (1, 2, 2)) if clamp: - mcs_tensor = torch.clamp(mcs_tensor, base_min, 1) - ssim_tensor = torch.clamp(ssim_tensor, base_min, 1) - - # Compute weighted geometric mean - ms_ssim_val = torch.prod(mcs_tensor ** betas_tensor[:-1]) * ( - ssim_tensor ** betas_tensor[-1] - ) - - return torch.mean(ms_ssim_val) + ssim = ssim.clamp(min=base_min) + mcs_list[-1] = ssim + mcs_stack = torch.stack(mcs_list) + betas = torch.tensor(betas, device=mcs_stack.device).view(-1, 1) + mcs_weighted = mcs_stack**betas + return torch.prod(mcs_weighted, axis=0).mean() diff --git a/viscy/translation/predict_writer.py b/viscy/translation/predict_writer.py index 29c0e41be..4aaa3394a 100644 --- a/viscy/translation/predict_writer.py +++ b/viscy/translation/predict_writer.py @@ -1,5 +1,3 @@ -"""Prediction writer for HCS virtual staining predictions in OME-Zarr format.""" - import logging import os from collections.abc import Sequence diff --git a/viscy/unet/__init__.py b/viscy/unet/__init__.py index 5e46d0396..e69de29bb 100644 --- a/viscy/unet/__init__.py +++ b/viscy/unet/__init__.py @@ -1 +0,0 @@ -"""U-Net architectures for VisCy.""" diff --git a/viscy/unet/networks/Unet25D.py b/viscy/unet/networks/Unet25D.py index 9d9c6fb1a..9cef5e93b 100644 --- a/viscy/unet/networks/Unet25D.py +++ b/viscy/unet/networks/Unet25D.py @@ -1,5 +1,3 @@ -"""2.5D U-Net implementation for volumetric image processing.""" - from typing import Literal import torch diff --git a/viscy/unet/networks/__init__.py b/viscy/unet/networks/__init__.py index c21eefda3..e69de29bb 100644 --- a/viscy/unet/networks/__init__.py +++ b/viscy/unet/networks/__init__.py @@ -1 +0,0 @@ -"""Neural network architectures for VisCy U-Net implementations.""" diff --git a/viscy/unet/networks/layers/ConvBlock2D.py b/viscy/unet/networks/layers/ConvBlock2D.py index a713edfad..7f9861fe5 100644 --- a/viscy/unet/networks/layers/ConvBlock2D.py +++ b/viscy/unet/networks/layers/ConvBlock2D.py @@ -1,5 +1,3 @@ -"""2D convolutional blocks for U-Net architectures.""" - from typing import Literal import numpy as np diff --git a/viscy/unet/networks/layers/ConvBlock3D.py b/viscy/unet/networks/layers/ConvBlock3D.py index 
6895d8baa..9cf96e33c 100644 --- a/viscy/unet/networks/layers/ConvBlock3D.py +++ b/viscy/unet/networks/layers/ConvBlock3D.py @@ -1,5 +1,3 @@ -"""3D convolutional blocks for volumetric neural network architectures.""" - from typing import Literal import numpy as np diff --git a/viscy/unet/networks/layers/__init__.py b/viscy/unet/networks/layers/__init__.py index 54f8699e9..e69de29bb 100644 --- a/viscy/unet/networks/layers/__init__.py +++ b/viscy/unet/networks/layers/__init__.py @@ -1 +0,0 @@ -"""Convolutional building blocks for neural network layers.""" diff --git a/viscy/unet/networks/unext2.py b/viscy/unet/networks/unext2.py index cf4b925da..ed683a602 100644 --- a/viscy/unet/networks/unext2.py +++ b/viscy/unet/networks/unext2.py @@ -1,5 +1,3 @@ -"""UNeXt2: ConvNeXt-based U-Net implementation for advanced neural network architectures.""" - from collections.abc import Callable, Sequence from typing import Literal diff --git a/viscy/utils/__init__.py b/viscy/utils/__init__.py index b0ba3e69f..e69de29bb 100644 --- a/viscy/utils/__init__.py +++ b/viscy/utils/__init__.py @@ -1 +0,0 @@ -"""Module for utility functions.""" diff --git a/viscy/utils/cli_utils.py b/viscy/utils/cli_utils.py index 96373fe70..9d81b3335 100644 --- a/viscy/utils/cli_utils.py +++ b/viscy/utils/cli_utils.py @@ -1,5 +1,3 @@ -"""Command-line interface utilities for data processing and visualization.""" - import collections import os import re diff --git a/viscy/utils/log_images.py b/viscy/utils/log_images.py index 3d26a5b7b..99dfa98f3 100644 --- a/viscy/utils/log_images.py +++ b/viscy/utils/log_images.py @@ -1,5 +1,3 @@ -"""Logging example images during training.""" - from collections.abc import Sequence import numpy as np diff --git a/viscy/utils/logging.py b/viscy/utils/logging.py index 25b9897fe..7320fc5b2 100644 --- a/viscy/utils/logging.py +++ b/viscy/utils/logging.py @@ -1,5 +1,3 @@ -"""Feature map logging utilities for neural network debugging.""" - import datetime import os import time diff --git a/viscy/utils/masks.py b/viscy/utils/masks.py index aba09fc83..f000fce94 100644 --- a/viscy/utils/masks.py +++ b/viscy/utils/masks.py @@ -1,5 +1,3 @@ -"""Mask generation and processing utilities.""" - from typing import Any import numpy as np diff --git a/viscy/utils/meta_utils.py b/viscy/utils/meta_utils.py index 0ede1d85d..65286a33e 100644 --- a/viscy/utils/meta_utils.py +++ b/viscy/utils/meta_utils.py @@ -1,5 +1,3 @@ -"""Metadata utilities for dataset analysis and normalization statistics.""" - import os import sys from pathlib import Path @@ -70,7 +68,7 @@ def _grid_sample( def generate_normalization_metadata( - zarr_dir: str | Path, + zarr_dir: str, num_workers: int = 4, channel_ids: list[int] | int = -1, grid_spacing: int = 32, @@ -200,7 +198,7 @@ def generate_normalization_metadata( def compute_normalization_stats( - image_data: NDArray, grid_spacing: int = 32 + image_data: np.ndarray, grid_spacing: int = 32 ) -> dict[str, float]: """Compute normalization statistics from image data using grid sampling. @@ -209,7 +207,7 @@ def compute_normalization_stats( image_data : np.ndarray 3D or 4D image array of shape (z, y, x) or (t, z, y, x). grid_spacing : int, optional - Spacing between grid points for sampling, by default 32. + Spacing betweend grid points for sampling, by default 32. 
Returns ------- diff --git a/viscy/utils/mp_utils.py b/viscy/utils/mp_utils.py index a37f059a7..65fc78071 100644 --- a/viscy/utils/mp_utils.py +++ b/viscy/utils/mp_utils.py @@ -1,5 +1,3 @@ -"""Multiprocessing utilities for parallel data processing.""" - from collections.abc import Callable from concurrent.futures import ProcessPoolExecutor from typing import Any diff --git a/viscy/utils/slurm_utils.py b/viscy/utils/slurm_utils.py index f2fc2bdf8..c943b8b11 100644 --- a/viscy/utils/slurm_utils.py +++ b/viscy/utils/slurm_utils.py @@ -1,5 +1,3 @@ -"""SLURM cluster utilities for resource management.""" - import psutil import torch From ecc3296eb06ff87f5d5ed70f9ec897b4b1c029e5 Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Thu, 11 Sep 2025 11:31:35 -0700 Subject: [PATCH 05/13] hcs_stack -> tuple[str,int,int] --- viscy/data/hcs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 785e24394..6ace39acd 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -201,7 +201,7 @@ def _find_window(self, index: int) -> tuple[ImageArray, int, NormMeta | None]: def _read_img_window( self, img: ImageArray, ch_idx: list[int], tz: int - ) -> tuple[tuple[Tensor, ...], HCSStackIndex]: + ) -> tuple[tuple[Tensor, ...], tuple[str, int, int]]: """Read image window as tensor. Parameters @@ -215,7 +215,7 @@ def _read_img_window( Returns ------- - tuple[tuple[Tensor], HCSStackIndex] + tuple[tuple[Tensor], tuple[str, int, int]] list of (C=1, Z, Y, X) image tensors, tuple of image name, time index, and Z index @@ -232,7 +232,7 @@ def _read_img_window( [int(i) for i in ch_idx], slice(z, z + self.z_window_size), ].astype(np.float32) - return torch.from_numpy(data).unbind(dim=1), HCSStackIndex(img.name, t, z) + return torch.from_numpy(data).unbind(dim=1), (img.name, t, z) def __len__(self) -> int: """Return total number of sliding windows across all FOVs.""" From f0df1626f3cd00ce9b2f597dafdeea6f2df04eff Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Thu, 11 Sep 2025 13:26:55 -0700 Subject: [PATCH 06/13] updated meta_utils.py --- viscy/data/hcs.py | 5 +- viscy/utils/meta_utils.py | 260 ++++++++++++++++++-------------------- 2 files changed, 123 insertions(+), 142 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 6ace39acd..e05eea269 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -3,9 +3,10 @@ import os import re import tempfile -from collections.abc import Callable, Sequence + +# from collections.abc import Callable, Sequence from pathlib import Path -from typing import Literal +from typing import Callable, Literal, Sequence import numpy as np import torch diff --git a/viscy/utils/meta_utils.py b/viscy/utils/meta_utils.py index 65286a33e..3547c39a7 100644 --- a/viscy/utils/meta_utils.py +++ b/viscy/utils/meta_utils.py @@ -1,6 +1,5 @@ import os import sys -from pathlib import Path import iohub.ngff as ngff import numpy as np @@ -11,7 +10,9 @@ from viscy.utils.mp_utils import get_val_stats -def write_meta_field(position: ngff.Position, metadata, field_name, subfield_name): +def write_meta_field( + position: ngff.Position, metadata: dict, field_name: str, subfield_name: str +): """Write metadata to position's plate-level or FOV level .zattrs metadata. 
Write metadata to position's plate-level or FOV level .zattrs metadata by either @@ -68,21 +69,13 @@ def _grid_sample( def generate_normalization_metadata( - zarr_dir: str, - num_workers: int = 4, - channel_ids: list[int] | int = -1, - grid_spacing: int = 32, + zarr_dir: str, num_workers: int = 4, channel_ids: int = -1, grid_spacing: int = 32 ): """Generate pixel intensity metadata for on-the-fly normalization. Generate pixel intensity metadata to be later used in on-the-fly normalization during training and inference. Sampling is used for efficient estimation of median and interquartile range for intensity values on both a dataset and field-of-view - level. - - Normalization values are recorded in the image-level metadata in the corresponding - position of each zarr_dir store. Format of metadata is as follows: - { channel_idx : { dataset_statistics: dataset level normalization values (positive float), fov_statistics: field-of-view level normalization values (positive float) @@ -106,152 +99,139 @@ def generate_normalization_metadata( plate = ngff.open_ome_zarr(zarr_dir, mode="r+") position_map = list(plate.positions()) - # Prepare parameters for multiprocessing - zarr_dir_path = os.path.dirname(os.path.dirname(zarr_dir)) - - # Get channels to process if channel_ids == -1: - # Get channel IDs from first position - first_position = position_map[0][1] - first_images = list(first_position.images()) - first_image = first_images[0][1] - # shape is (t, c, z, y, x) - channel_ids = list(range(first_image.data.shape[1])) - - if isinstance(channel_ids, int): + channel_ids = range(len(plate.channel_names)) + elif isinstance(channel_ids, int): channel_ids = [channel_ids] - # Prepare parameters for each position and channel - params_list = [] - for position_idx, (position_key, position) in enumerate(position_map): - for channel_id in channel_ids: - params = { - "zarr_dir": zarr_dir, - "position_key": position_key, - "channel_id": channel_id, - "grid_spacing": grid_spacing, - } - params_list.append(params) - - # Use multiprocessing to compute normalization statistics - progress_bar = show_progress_bar() - if num_workers > 1: - with mp_utils.get_context("spawn").Pool(num_workers) as pool: - results = pool.map(mp_utils.normalize_meta_worker, params_list) - progress_bar.update(len(params_list)) - else: - results = [] - for params in params_list: - result = mp_utils.normalize_meta_worker(params) - results.append(result) - progress_bar.update(1) - - progress_bar.close() - - # Aggregate results and write to metadata - all_dataset_stats = {} - for result in results: - if result is not None: - position_key, channel_id, dataset_stats, fov_stats = result - - if channel_id not in all_dataset_stats: - all_dataset_stats[channel_id] = [] - all_dataset_stats[channel_id].append(dataset_stats) - - # Calculate dataset-level statistics - final_dataset_stats = {} - for channel_id, stats_list in all_dataset_stats.items(): - if stats_list: - # Aggregate median and IQR across all positions - medians = [stats["median"] for stats in stats_list if "median" in stats] - iqrs = [stats["iqr"] for stats in stats_list if "iqr" in stats] - - if medians and iqrs: - final_dataset_stats[channel_id] = { - "median": np.median(medians), - "iqr": np.median(iqrs), - } - - # Write metadata to each position - for result in results: - if result is not None: - position_key, channel_id, dataset_stats, fov_stats = result - - # Get position object - position = dict(plate.positions())[position_key] - - # Prepare metadata - metadata = { - 
"dataset_statistics": final_dataset_stats.get(channel_id, {}), - "fov_statistics": fov_stats, - } + # get arguments for multiprocessed grid sampling + mp_grid_sampler_args = [] + for _, position in position_map: + mp_grid_sampler_args.append([position, grid_spacing]) + + # sample values and use them to get normalization statistics + for i, channel_index in enumerate(channel_ids): + print(f"Sampling channel index {channel_index} ({i + 1}/{len(channel_ids)})") - # Write metadata + channel_name = plate.channel_names[channel_index] + dataset_sample_values = [] + position_and_statistics = [] + + for _, pos in tqdm(position_map, desc="Positions"): + samples = _grid_sample(pos, grid_spacing, channel_index, num_workers) + dataset_sample_values.append(samples) + fov_level_statistics = {"fov_statistics": get_val_stats(samples)} + position_and_statistics.append((pos, fov_level_statistics)) + + dataset_statistics = { + "dataset_statistics": get_val_stats(np.stack(dataset_sample_values)), + } + write_meta_field( + position=plate, + metadata=dataset_statistics, + field_name="normalization", + subfield_name=channel_name, + ) + + for pos, position_statistics in position_and_statistics: write_meta_field( - position=position, - metadata=metadata, + position=pos, + metadata=dataset_statistics | position_statistics, field_name="normalization", - subfield_name=str(channel_id), + subfield_name=channel_name, ) plate.close() -def compute_normalization_stats( - image_data: np.ndarray, grid_spacing: int = 32 -) -> dict[str, float]: +def compute_zscore_params( + frames_meta, ints_meta, input_dir, normalize_im, min_fraction=0.99 +): """Compute normalization statistics from image data using grid sampling. + Compute zscore median and interquartile range. + Parameters ---------- - image_data : np.ndarray - 3D or 4D image array of shape (z, y, x) or (t, z, y, x). - grid_spacing : int, optional - Spacing betweend grid points for sampling, by default 32. + frames_meta : pd.DataFrame + Dataframe containing all metadata. + ints_meta : pd.DataFrame + Metadata containing intensity statistics each z-slice and foreground fraction for masks. + input_dir : str + Directory containing images. + normalize_im : None or str + Normalization scheme for input images. + min_fraction : float + Minimum foreground fraction (in case of masks) for computing intensity statistics. + for computing intensity statistics. Returns ------- - dict[str, float] - Dictionary with median and IQR statistics for normalization. 
+ tuple[pd.DataFrame, pd.DataFrame] + Tuple containing: + - pd.DataFrame frames_meta: Dataframe containing all metadata + - pd.DataFrame ints_meta: Metadata containing intensity statistics of each z-slice """ - # Handle different input shapes - if image_data.ndim == 4: - # Assume (t, z, y, x) and take first timepoint - image_data = image_data[0] - - if image_data.ndim == 3: - # Assume (z, y, x) and use middle z-slice if available - if image_data.shape[0] > 1: - z_mid = image_data.shape[0] // 2 - image_data = image_data[z_mid] - else: - image_data = image_data[0] - - # Now image_data should be 2D (y, x) - if image_data.ndim != 2: - raise ValueError(f"Expected 2D image after processing, got {image_data.ndim}D") - - # Create sampling grid - y_indices = np.arange(0, image_data.shape[0], grid_spacing) - x_indices = np.arange(0, image_data.shape[1], grid_spacing) - - # Sample values at grid points - sampled_values = image_data[np.ix_(y_indices, x_indices)].flatten() - - # Remove any NaN or infinite values - sampled_values = sampled_values[np.isfinite(sampled_values)] - - if len(sampled_values) == 0: - return {"median": 0.0, "iqr": 1.0} - - # Compute statistics - median = np.median(sampled_values) - q25 = np.percentile(sampled_values, 25) - q75 = np.percentile(sampled_values, 75) - iqr = q75 - q25 - - # Avoid zero IQR - if iqr == 0: - iqr = 1.0 + assert normalize_im in [ + None, + "slice", + "volume", + "dataset", + ], 'normalize_im must be None or "slice" or "volume" or "dataset"' + + if normalize_im is None: + # No normalization + frames_meta["zscore_median"] = 0 + frames_meta["zscore_iqr"] = 1 + return frames_meta + elif normalize_im == "dataset": + agg_cols = ["time_idx", "channel_idx", "dir_name"] + elif normalize_im == "volume": + agg_cols = ["time_idx", "channel_idx", "dir_name", "pos_idx"] + else: + agg_cols = ["time_idx", "channel_idx", "dir_name", "pos_idx", "slice_idx"] + # median and inter-quartile range are more robust than mean and std + ints_meta_sub = ints_meta[ints_meta["fg_frac"] >= min_fraction] + ints_agg_median = ints_meta_sub[agg_cols + ["intensity"]].groupby(agg_cols).median() + ints_agg_hq = ( + ints_meta_sub[agg_cols + ["intensity"]].groupby(agg_cols).quantile(0.75) + ) + ints_agg_lq = ( + ints_meta_sub[agg_cols + ["intensity"]].groupby(agg_cols).quantile(0.25) + ) + ints_agg = ints_agg_median + ints_agg.columns = ["zscore_median"] + ints_agg["zscore_iqr"] = ints_agg_hq["intensity"] - ints_agg_lq["intensity"] + ints_agg.reset_index(inplace=True) + + cols_to_merge = frames_meta.columns[ + [col not in ["zscore_median", "zscore_iqr"] for col in frames_meta.columns] + ] + frames_meta = pd.merge( + frames_meta[cols_to_merge], + ints_agg, + how="left", + on=agg_cols, + ) + if frames_meta["zscore_median"].isnull().values.any(): + raise ValueError( + "Found NaN in normalization parameters. \ + min_fraction might be too low or images might be corrupted." 
+    ) +    frames_meta_filename = os.path.join(input_dir, "frames_meta.csv") +    frames_meta.to_csv(frames_meta_filename, sep=",") + +    cols_to_merge = ints_meta.columns[ +        [col not in ["zscore_median", "zscore_iqr"] for col in ints_meta.columns] +    ] +    ints_meta = pd.merge( +        ints_meta[cols_to_merge], +        ints_agg, +        how="left", +        on=agg_cols, +    ) +    ints_meta["intensity_norm"] = ( +        ints_meta["intensity"] - ints_meta["zscore_median"] +    ) / (ints_meta["zscore_iqr"] + sys.float_info.epsilon)  -    return {"median": float(median), "iqr": float(iqr)} +    return frames_meta, ints_meta  From 936a9543a00749d96c45e9339a99c5a91876140c Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Fri, 12 Sep 2025 10:16:23 -0700 Subject: [PATCH 07/13] reverted some type hints --- viscy/representation/embedding_writer.py | 5 +-- viscy/representation/evaluation/__init__.py | 6 +-- viscy/representation/evaluation/clustering.py | 8 +++- .../evaluation/dimensionality_reduction.py | 31 ++++++++------ viscy/representation/evaluation/distance.py | 42 +++++++++++++------ viscy/representation/evaluation/lca.py | 6 +-- 6 files changed, 62 insertions(+), 36 deletions(-) diff --git a/viscy/representation/embedding_writer.py b/viscy/representation/embedding_writer.py index 0e4db683c..e7a89c079 100644 --- a/viscy/representation/embedding_writer.py +++ b/viscy/representation/embedding_writer.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd import torch -import xarray as xr from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import BasePredictionWriter from numpy.typing import NDArray @@ -23,7 +22,7 @@ _logger = logging.getLogger("lightning.pytorch") -def read_embedding_dataset(path: Path) -> xr.Dataset: +def read_embedding_dataset(path: Path) -> Dataset: """Read the embedding dataset written by the EmbeddingWriter callback. Supports both legacy datasets (without x/y coordinates) and new datasets. @@ -35,7 +34,7 @@ def read_embedding_dataset(path: Path) -> xr.Dataset: Returns ------- -    xr.Dataset +    Dataset Xarray dataset with features and projections. """ dataset = open_zarr(path) diff --git a/viscy/representation/evaluation/__init__.py b/viscy/representation/evaluation/__init__.py index 36d899b2b..1c43f6295 100644 --- a/viscy/representation/evaluation/__init__.py +++ b/viscy/representation/evaluation/__init__.py @@ -18,19 +18,19 @@ from pathlib import Path import pandas as pd -import xarray as xr from viscy.data.triplet import TripletDataModule +from xarray import DataArray def load_annotation( -    da: xr.DataArray, path: str, name: str, categories: dict | None = None +    da: DataArray, path: str, name: str, categories: dict | None = None ) -> pd.Series: """ Load annotations from a CSV file and map them to the dataset. Parameters ---------- -    da : DataArray +    da : DataArray The dataset array containing 'fov_name' and 'id' coordinates. path : str Path to the CSV file containing annotations. 
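The compute_zscore_params function restored in the meta_utils.py rewrite above brings back median/IQR z-scoring: within each aggregation group (dataset, volume, or slice), intensities are normalized as (x - median) / (iqr + eps). A self-contained sketch of the per-group statistics, for illustration only (not code from this series):

    import numpy as np

    def zscore_params(intensities: np.ndarray) -> tuple[float, float]:
        # Median and interquartile range are more robust to outliers
        # than mean and standard deviation.
        median = float(np.median(intensities))
        iqr = float(np.quantile(intensities, 0.75) - np.quantile(intensities, 0.25))
        return median, iqr

    values = np.random.lognormal(mean=2.0, sigma=0.5, size=1000)
    median, iqr = zscore_params(values)
    normalized = (values - median) / (iqr + np.finfo(np.float64).eps)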
diff --git a/viscy/representation/evaluation/clustering.py b/viscy/representation/evaluation/clustering.py index afe09a8cd..c5599229b 100644 --- a/viscy/representation/evaluation/clustering.py +++ b/viscy/representation/evaluation/clustering.py @@ -1,3 +1,5 @@ +"""Methods for evaluating clustering performance.""" + import numpy as np from numpy.typing import ArrayLike, NDArray from scipy.spatial.distance import cdist @@ -10,12 +12,16 @@ from sklearn.neighbors import KNeighborsClassifier -def knn_accuracy(embeddings, annotations, k=5): +def knn_accuracy(embeddings: NDArray, annotations: NDArray, k: int = 5) -> float: """ Evaluate the k-NN classification accuracy. Parameters ---------- +    embeddings : NDArray +        Embeddings to cluster. +    annotations : NDArray +        Ground truth labels. k : int, optional Number of neighbors to use for k-NN. Default is 5. diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index ed1c9c47c..290259eb2 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -1,25 +1,30 @@ +from typing import TYPE_CHECKING + import pandas as pd import umap -import xarray as xr from numpy.typing import NDArray from sklearn.decomposition import PCA from sklearn.preprocessing import StandardScaler +from xarray import Dataset + +if TYPE_CHECKING: +    from phate import PHATE def compute_phate( -    embedding_dataset: NDArray | xr.Dataset, +    embedding_dataset: NDArray | Dataset, n_components: int = 2, knn: int = 5, decay: int = 40, update_dataset: bool = False, **phate_kwargs, -) -> tuple[object, NDArray]: +) -> tuple["PHATE", NDArray]: """ Compute PHATE embeddings for features and optionally update dataset. Parameters ---------- -    embedding_dataset : xr.Dataset | NDArray +    embedding_dataset : NDArray | Dataset The dataset containing embeddings, timepoints, fov_name, and track_id, or a numpy array of embeddings. 
n_components : int, optional @@ -93,7 +98,7 @@ def compute_pca(embedding_dataset, n_components=None, normalize_features=True): """ embeddings = ( embedding_dataset["features"].values - if isinstance(embedding_dataset, xr.Dataset) + if isinstance(embedding_dataset, Dataset) else embedding_dataset ) @@ -107,7 +112,7 @@ def compute_pca(embedding_dataset, n_components=None, normalize_features=True): pc_features = PCA_features.fit_transform(scaled_features) # Create base dictionary with id and fov_name - if isinstance(embedding_dataset, xr.Dataset): + if isinstance(embedding_dataset, Dataset): pca_dict = { "id": embedding_dataset["id"].values, "fov_name": embedding_dataset["fov_name"].values, @@ -139,13 +144,13 @@ def _fit_transform_umap( def compute_umap( - embedding_dataset: xr.Dataset, normalize_features: bool = True + embedding_dataset: Dataset, normalize_features: bool = True ) -> tuple[umap.UMAP, umap.UMAP, pd.DataFrame]: """Compute UMAP embeddings for features and projections. Parameters ---------- - embedding_dataset : xr.Dataset + embedding_dataset : Dataset Xarray dataset with features and projections. normalize_features : bool, optional Scale the input to zero mean and unit variance before fitting UMAP, diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index fae2e2248..bc4d5d13a 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -2,14 +2,32 @@ from typing import Literal import numpy as np -import xarray as xr +from numpy.typing import NDArray from sklearn.metrics.pairwise import cosine_similarity +from xarray import Dataset def calculate_cosine_similarity_cell( - embedding_dataset: xr.Dataset, fov_name: str, track_id: int -): - """Extract embeddings and calculate cosine similarities for a specific cell""" + embedding_dataset: Dataset, fov_name: str, track_id: int +) -> tuple[NDArray, NDArray]: + """ + + Extract embeddings and calculate cosine similarities for a specific cell + + Parameters + ---------- + embedding_dataset : Dataset + Dataset containing embeddings and metadata + fov_name : str + Field of view identifier + track_id : int + Track identifier for the specific cell + + Returns + ------- + tuple[NDArray, NDArray] + Time points and cosine similarities for the specific cell + """ filtered_data = embedding_dataset.where( (embedding_dataset["fov_name"] == fov_name) & (embedding_dataset["track_id"] == track_id), @@ -25,7 +43,7 @@ def calculate_cosine_similarity_cell( def compute_displacement( - embedding_dataset: xr.Dataset, + embedding_dataset: Dataset, distance_metric: Literal["euclidean_squared", "cosine"] = "euclidean_squared", ) -> dict[int, list[float]]: """Compute the displacement or mean square displacement (MSD) of embeddings. @@ -37,15 +55,13 @@ def compute_displacement( Parameters ---------- - embedding_dataset : xarray.Dataset + embedding_dataset : Dataset Dataset containing embeddings and metadata - distance_metric : str + distance_metric : Literal["euclidean_squared", "cosine"] The metric to use for computing distances between embeddings. 
Valid options are: - - "euclidean": Euclidean distance (L2 norm) - "euclidean_squared": Squared Euclidean distance (for MSD, default) - "cosine": Cosine similarity - - "cosine_dissimilarity": 1 - cosine similarity Returns ------- @@ -152,13 +168,13 @@ def compute_dynamic_range(mean_displacement_per_tau: dict[int, float]): return max(displacements) - min(displacements) -def compute_rms_per_track(embedding_dataset: xr.Dataset): +def compute_rms_per_track(embedding_dataset: Dataset): """ Compute RMS of the time derivative of embeddings per track. Parameters ---------- - embedding_dataset : xarray.Dataset + embedding_dataset : Dataset The dataset containing embeddings, timepoints, fov_name, and track_id. Returns @@ -204,13 +220,13 @@ def compute_rms_per_track(embedding_dataset: xr.Dataset): def calculate_normalized_euclidean_distance_cell( - embedding_dataset: xr.Dataset, fov_name: str, track_id: int + embedding_dataset: Dataset, fov_name: str, track_id: int ): """Calculate normalized euclidean distance for a specific cell track. Parameters ---------- - embedding_dataset : xr.Dataset + embedding_dataset : Dataset Dataset containing embedding data with fov_name and track_id coordinates fov_name : str Field of view identifier diff --git a/viscy/representation/evaluation/lca.py b/viscy/representation/evaluation/lca.py index 89e2f6142..9090e069f 100644 --- a/viscy/representation/evaluation/lca.py +++ b/viscy/representation/evaluation/lca.py @@ -5,7 +5,6 @@ import pandas as pd import torch import torch.nn as nn -import xarray as xr from captum.attr import IntegratedGradients, Occlusion from numpy.typing import NDArray from sklearn.linear_model import LogisticRegression @@ -13,10 +12,11 @@ from sklearn.preprocessing import StandardScaler from torch import Tensor from viscy.representation.contrastive import ContrastiveEncoder +from xarray import DataArray def fit_logistic_regression( - features: xr.DataArray, + features: DataArray, annotations: pd.Series, train_fovs: list[str], remove_background_class: bool = True, @@ -32,7 +32,7 @@ def fit_logistic_regression( Parameters ---------- - features : xr.DataArray + features : DataArray Xarray of features. annotations : pd.Series Categorical class annotations with label values starting from 0. From e5a10e699d551b8873f55383dbc272f3d4e24d15 Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Fri, 12 Sep 2025 10:17:54 -0700 Subject: [PATCH 08/13] ruff check --- viscy/representation/evaluation/dimensionality_reduction.py | 4 +++- viscy/representation/evaluation/distance.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/viscy/representation/evaluation/dimensionality_reduction.py b/viscy/representation/evaluation/dimensionality_reduction.py index 290259eb2..cc409c30b 100644 --- a/viscy/representation/evaluation/dimensionality_reduction.py +++ b/viscy/representation/evaluation/dimensionality_reduction.py @@ -78,7 +78,9 @@ def compute_phate( return phate_model, phate_embedding -def compute_pca(embedding_dataset: NDArray | Dataset, n_components=None, normalize_features=True): +def compute_pca( + embedding_dataset: NDArray | Dataset, n_components=None, normalize_features=True +): """Compute PCA embeddings for features and optionally update dataset. 
Parameters diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py index bc4d5d13a..fd8df30af 100644 --- a/viscy/representation/evaluation/distance.py +++ b/viscy/representation/evaluation/distance.py @@ -11,9 +11,9 @@ def calculate_cosine_similarity_cell( embedding_dataset: Dataset, fov_name: str, track_id: int ) -> tuple[NDArray, NDArray]: """ - + Extract embeddings and calculate cosine similarities for a specific cell - + Parameters ---------- embedding_dataset : Dataset From d6fafb57a26ce83e1bd74a7e065c4f2175e1dc08 Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Fri, 12 Sep 2025 10:37:11 -0700 Subject: [PATCH 09/13] missed a couple of docstrings --- pyproject.toml | 1 - tests/preprocessing/resize_images_tests.py | 4 ++-- viscy/data/combined.py | 15 +++++++++++++- viscy/data/hcs.py | 24 ++++++++++++++++------ 4 files changed, 34 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 67a958a14..37a298fb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,6 +104,5 @@ ignore = [ "D400", # first line should end with a period [Bug: doesn’t work with single-line docstrings] "D401", # First line should be in imperative mood; try rephrasing ] -per-file-ignores."*/__init__.py" = ["F401"] per-file-ignores."tests/*" = ["D"] pydocstyle.convention = "numpy" \ No newline at end of file diff --git a/tests/preprocessing/resize_images_tests.py b/tests/preprocessing/resize_images_tests.py index 5d237b83b..2cf399600 100644 --- a/tests/preprocessing/resize_images_tests.py +++ b/tests/preprocessing/resize_images_tests.py @@ -132,7 +132,7 @@ def test_resize_volumes(self): ), ignore_index=True, ) - op_fname = f"im_c00{c}_z000_t005_p007_3.3-0.8-1.0.npy" + op_fname = "im_c00{c}_z000_t005_p007_3.3-0.8-1.0.npy".format(c=c) exp_meta_dict.append( { "time_idx": self.time_idx, @@ -168,7 +168,7 @@ def test_resize_volumes(self): exp_meta_dict = [] for c in channel_ids: for s in [0, 2]: - op_fname = f"im_c00{c}_z00{s}_t005_p007_3.3-0.8-1.0.npy" + op_fname = "im_c00{c}_z00{s}_t005_p007_3.3-0.8-1.0.npy".format(c=c, s=s) exp_meta_dict.append( { "time_idx": self.time_idx, diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 9fbdfab9c..b0e4a85b6 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -396,7 +396,20 @@ class CachedConcatDataModule(LightningDataModule): Concatenates multiple data modules with support for distributed sampling and caching optimizations for large-scale ML training. - # TODO: MANUAL_REVIEW - Verify caching behavior and memory usage + + Parameters + ---------- + data_modules : Sequence[LightningDataModule] + Data modules to concatenate. + + Raises + ------ + ValueError + If inconsistent number of workers or batch size across data modules. + NotImplementedError + If stage other than "fit" is requested. + ValueError + If patches per stack are inconsistent across data modules. """ def __init__(self, data_modules: Sequence[LightningDataModule]): diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index e05eea269..1559e06f8 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -3,8 +3,6 @@ import os import re import tempfile - -# from collections.abc import Callable, Sequence from pathlib import Path from typing import Callable, Literal, Sequence @@ -35,8 +33,15 @@ def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]: """ Ensure channel argument is a list of strings. 
- :param Union[str, Sequence[str]] str_or_seq: channel name or list of channel names - :return list[str]: list of channel names + Parameters + ---------- + str_or_seq : str | Sequence[str] + Channel name or list of channel names + + Returns + ------- + list[str] + List of channel names """ if isinstance(str_or_seq, str): return [str_or_seq] @@ -81,10 +86,17 @@ def _search_int_in_str(pattern: str, file_name: str) -> str: def _collate_samples(batch: Sequence[Sample]) -> Sample: """Collate samples into a batch sample. - :param Sequence[Sample] batch: a sequence of dictionaries, + Parameters + ---------- + batch : Sequence[Sample] + A sequence of dictionaries, where each key may point to a value of a single tensor or a list of tensors, as is the case with ``train_patches_per_stack > 1``. - :return Sample: Batch sample (dictionary of tensors) + + Returns + ------- + Sample + Batch sample (dictionary of tensors) """ collated: Sample = {} for key in batch[0].keys(): From cf230555f2e13c9699f9e9f2a8d1799a989105be Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Fri, 12 Sep 2025 10:47:43 -0700 Subject: [PATCH 10/13] undid resize_images_tests.py::test_resize_volumes --- tests/preprocessing/resize_images_tests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/preprocessing/resize_images_tests.py b/tests/preprocessing/resize_images_tests.py index 2cf399600..5b41d0390 100644 --- a/tests/preprocessing/resize_images_tests.py +++ b/tests/preprocessing/resize_images_tests.py @@ -132,7 +132,7 @@ def test_resize_volumes(self): ), ignore_index=True, ) - op_fname = "im_c00{c}_z000_t005_p007_3.3-0.8-1.0.npy".format(c=c) + op_fname = "im_c00{c}_z000_t005_p007_3.3-0.8-1.0.npy".format(c) exp_meta_dict.append( { "time_idx": self.time_idx, @@ -168,7 +168,7 @@ def test_resize_volumes(self): exp_meta_dict = [] for c in channel_ids: for s in [0, 2]: - op_fname = "im_c00{c}_z00{s}_t005_p007_3.3-0.8-1.0.npy".format(c=c, s=s) + op_fname = "im_c00{c}_z00{s}_t005_p007_3.3-0.8-1.0.npy".format(c, s) exp_meta_dict.append( { "time_idx": self.time_idx, From b6635d4f96894206e8ac7f6f8ee3ee2e2ccd8710 Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Wed, 17 Sep 2025 17:38:36 -0700 Subject: [PATCH 11/13] Resolve merge conflicts for docstring standardization - Fixed conflicts in data modules (combined.py, triplet.py, hcs.py) - Resolved classification.py parameter conflicts - Updated transforms __init__.py imports - Cleaned up _transforms.py duplicate classes --- viscy/data/cell_classification.py | 23 +- viscy/data/combined.py | 48 +++- viscy/data/hcs.py | 22 +- viscy/data/triplet.py | 313 ++++++++++++++++++------- viscy/representation/classification.py | 5 +- viscy/scripts/bench_augmentations.py | 44 ++-- viscy/scripts/profiling.py | 137 +++++------ viscy/transforms/__init__.py | 57 ++++- viscy/transforms/_transforms.py | 57 ----- 9 files changed, 435 insertions(+), 271 deletions(-) diff --git a/viscy/data/cell_classification.py b/viscy/data/cell_classification.py index 7b42cd82b..a30b55aff 100644 --- a/viscy/data/cell_classification.py +++ b/viscy/data/cell_classification.py @@ -46,6 +46,7 @@ def __init__( transform: Callable | None, initial_yx_patch_size: tuple[int, int], return_indices: bool = False, + label_column: str = "infection_state", ): self.plate = plate self.z_range = z_range @@ -65,6 +66,7 @@ def __init__( annotation["y"].between(*y_range, inclusive="neither") & annotation["x"].between(*x_range, inclusive="neither") ] + self.label_column = label_column def __len__(self): 
"""Return the number of samples in the dataset.""" @@ -103,7 +105,7 @@ def __getitem__( img = (image - norm_meta["mean"]) / norm_meta["std"] if self.transform is not None: img = self.transform(img) - label = torch.tensor(row["infection_state"]).float()[None] + label = torch.tensor(row[self.label_column]).float()[None] if self.return_indices: return img, label, row[INDEX_COLUMNS].to_dict() else: @@ -149,12 +151,13 @@ def __init__( val_fovs: list[str] | None, channel_name: str, z_range: tuple[int, int], - train_exlude_timepoints: list[int], + train_exclude_timepoints: list[int], train_transforms: list[Callable] | None, val_transforms: list[Callable] | None, initial_yx_patch_size: tuple[int, int], batch_size: int, num_workers: int, + label_column: str = "infection_state", ): super().__init__() self.image_path = image_path @@ -162,12 +165,13 @@ def __init__( self.val_fovs = val_fovs self.channel_name = channel_name self.z_range = z_range - self.train_exlude_timepoints = train_exlude_timepoints + self.train_exclude_timepoints = train_exclude_timepoints self.train_transform = Compose(train_transforms) self.val_transform = Compose(val_transforms) self.initial_yx_patch_size = initial_yx_patch_size self.batch_size = batch_size self.num_workers = num_workers + self.label_column = label_column def _subset( self, @@ -189,6 +193,7 @@ def _subset( transform=transform, initial_yx_patch_size=self.initial_yx_patch_size, return_indices=return_indices, + label_column=self.label_column, ) def setup(self, stage=None) -> None: @@ -208,8 +213,16 @@ def setup(self, stage=None) -> None: If stage is unknown. """ plate = open_ome_zarr(self.image_path) - all_fovs = ["/" + name for (name, _) in plate.positions()] annotation = pd.read_csv(self.annotation_path) + all_fovs = [name for (name, _) in plate.positions()] + if annotation["fov_name"].iloc[0].startswith("/"): + all_fovs = ["/" + name for name in all_fovs] + if all_fovs[0].startswith("/"): + if not self.val_fovs[0].startswith("/"): + self.val_fovs = ["/" + name for name in self.val_fovs] + else: + if self.val_fovs[0].startswith("/"): + self.val_fovs = [name[1:] for name in self.val_fovs] for column in ("t", "y", "x"): annotation[column] = annotation[column].astype(int) if stage in (None, "fit", "validate"): @@ -219,7 +232,7 @@ def setup(self, stage=None) -> None: annotation, train_fovs, transform=self.train_transform, - exclude_timepoints=self.train_exlude_timepoints, + exclude_timepoints=self.train_exclude_timepoints, ) self.val_dataset = self._subset( plate, diff --git a/viscy/data/combined.py b/viscy/data/combined.py index b0e4a85b6..2d2f0c9c8 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -189,7 +189,7 @@ def _get_sample_indices(self, idx: int) -> tuple[int, int]: sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] return dataset_idx, sample_idx - def __getitems__(self, indices: list[int]) -> list: + def __getitems__(self, indices: list[int]) -> list[dict[str, torch.Tensor]]: """Retrieve multiple items by indices with batched dataset access. Groups indices by source dataset and performs batched retrieval @@ -202,7 +202,7 @@ def __getitems__(self, indices: list[int]) -> list: Returns ------- - list + list[dict[str, torch.Tensor]] Samples from all requested indices, maintaining order. 
""" grouped_indices = defaultdict(list) @@ -210,11 +210,14 @@ def __getitems__(self, indices: list[int]) -> list: dataset_idx, sample_indices = self._get_sample_indices(idx) grouped_indices[dataset_idx].append(sample_indices) _logger.debug(f"Grouped indices: {grouped_indices}") - sub_batches = [] + + micro_batches = [] for dataset_idx, sample_indices in grouped_indices.items(): - sub_batch = self.datasets[dataset_idx].__getitems__(sample_indices) - sub_batches.extend(sub_batch) - return sub_batches + micro_batch = self.datasets[dataset_idx].__getitems__(sample_indices) + micro_batch["_dataset_idx"] = dataset_idx + micro_batches.append(micro_batch) + + return micro_batches class ConcatDataModule(LightningDataModule): @@ -369,6 +372,7 @@ def train_dataloader(self) -> ThreadDataLoader: batch_size=self.batch_size, shuffle=True, drop_last=True, + collate_fn=lambda x: x, **self._dataloader_kwargs(), ) @@ -387,9 +391,41 @@ def val_dataloader(self) -> ThreadDataLoader: batch_size=self.batch_size, shuffle=False, drop_last=False, + collate_fn=lambda x: x, **self._dataloader_kwargs(), ) + def on_after_batch_transfer(self, batch, dataloader_idx: int): + """Apply GPU transforms from constituent data modules to micro-batches.""" + processed_micro_batches = [] + for micro_batch in batch: + dataset_idx = micro_batch.pop("_dataset_idx") + dm = self.data_modules[dataset_idx] + if hasattr(dm, "on_after_batch_transfer"): + processed_micro_batch = dm.on_after_batch_transfer( + micro_batch, dataloader_idx + ) + else: + processed_micro_batch = micro_batch + processed_micro_batches.append(processed_micro_batch) + combined_batch = {} + for key in processed_micro_batches[0].keys(): + if isinstance(processed_micro_batches[0][key], list): + combined_batch[key] = [] + for micro_batch in processed_micro_batches: + if key in micro_batch: + combined_batch[key].extend(micro_batch[key]) + else: + tensors_to_concat = [ + micro_batch[key] + for micro_batch in processed_micro_batches + if key in micro_batch + ] + if tensors_to_concat: + combined_batch[key] = torch.cat(tensors_to_concat, dim=0) + + return combined_batch + class CachedConcatDataModule(LightningDataModule): """Cached concatenated data module for distributed training. diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 1559e06f8..078234db1 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -735,22 +735,24 @@ def _fit_transform(self) -> tuple[Compose, Compose]: Training and validation transform compositions """ # TODO: These have a fixed order for now... () - final_crop = [ - CenterSpatialCropd( - keys=self.source_channel + self.target_channel, - roi_size=( - self.z_window_size, - self.yx_patch_size[0], - self.yx_patch_size[1], - ), - ) - ] + final_crop = [self._final_crop()] train_transform = Compose( self.normalizations + self._train_transform() + final_crop ) val_transform = Compose(self.normalizations + final_crop) return train_transform, val_transform + def _final_crop(self) -> CenterSpatialCropd: + """Setup final cropping: center crop to the target size.""" + return CenterSpatialCropd( + keys=self.source_channel + self.target_channel, + roi_size=( + self.z_window_size, + self.yx_patch_size[0], + self.yx_patch_size[1], + ), + ) + def _train_transform(self) -> list[Callable]: """Set up training augmentations. 
diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py
index b899ac15e..a22a4b251 100644
--- a/viscy/data/triplet.py
+++ b/viscy/data/triplet.py
@@ -1,22 +1,24 @@
 import logging
+import os
+import warnings
 from collections.abc import Sequence
 from pathlib import Path
 from typing import Literal
 
-import numpy as np
 import pandas as pd
 import tensorstore as ts
 import torch
 from iohub.ngff import ImageArray, Position, open_ome_zarr
-from monai.data import ThreadDataLoader
+from monai.data.thread_buffer import ThreadDataLoader
 from monai.data.utils import collate_meta_tensor
-from monai.transforms import Compose, MapTransform, ToDeviced
+from monai.transforms import Compose, MapTransform
 from torch import Tensor
 from torch.utils.data import Dataset
 
 from viscy.data.hcs import HCSDataModule, _read_norm_meta
 from viscy.data.select import _filter_fovs, _filter_wells
-from viscy.data.typing import DictTransform, NormMeta, TripletSample
+from viscy.data.typing import DictTransform, NormMeta
+from viscy.transforms import BatchedCenterSpatialCropd
 
 _logger = logging.getLogger("lightning.pytorch")
 
@@ -46,13 +48,10 @@ def _scatter_channels(
 
 
 def _gather_channels(
-    patch_channels: list[dict[str, Tensor | NormMeta]],
-) -> list[Tensor]:
-    samples = []
-    for sample in patch_channels:
-        sample.pop("norm_meta", None)
-        samples.append(torch.cat(list(sample.values()), dim=0))
-    return samples
+    patch_channels: dict[str, Tensor | NormMeta],
+) -> Tensor:
+    patch_channels.pop("norm_meta", None)
+    return torch.cat(list(patch_channels.values()), dim=1)
 
 
 def _transform_channel_wise(
@@ -118,25 +117,55 @@ def __init__(
         channel_names: list[str],
         initial_yx_patch_size: tuple[int, int],
         z_range: slice,
-        anchor_transform: DictTransform | None = None,
-        positive_transform: DictTransform | None = None,
-        negative_transform: DictTransform | None = None,
         fit: bool = True,
         predict_cells: bool = False,
         include_fov_names: list[str] | None = None,
         include_track_ids: list[int] | None = None,
         time_interval: Literal["any"] | int = "any",
         return_negative: bool = True,
+        cache_pool_bytes: int = 0,
     ) -> None:
+        """Dataset for triplet sampling of cells based on tracking.
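+
+        A hedged construction sketch (names and sizes here are hypothetical,
+        not part of the original code)::
+
+            dataset = TripletDataset(
+                positions=positions,
+                tracks_tables=tracks_tables,
+                channel_names=["Phase3D"],
+                initial_yx_patch_size=(384, 384),
+                z_range=slice(5, 35),
+            )
+            batch = dataset.__getitems__([0, 1])  # keys include "anchor"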
+
+        Parameters
+        ----------
+        positions : list[Position]
+            OME-Zarr images with consistent channel order
+        tracks_tables : list[pd.DataFrame]
+            Data frames containing ultrack results
+        channel_names : list[str]
+            Input channel names
+        initial_yx_patch_size : tuple[int, int]
+            YX size of the initially sampled image patch
+        z_range : slice
+            Range of Z-slices
+        fit : bool, optional
+            Fitting mode in which the full triplet will be sampled,
+            only sample anchor if ``False``, by default True
+        predict_cells : bool, optional
+            Only predict on selected cells, by default False
+        include_fov_names : list[str] | None, optional
+            Only predict on selected FOVs, by default None
+        include_track_ids : list[int] | None, optional
+            Only predict on selected track IDs, by default None
+        time_interval : Literal["any"] | int, optional
+            Future time interval to sample positive and anchor from,
+            by default "any"
+            (sample the negative from another track at any time point
+            and use the augmented anchor patch as the positive)
+        return_negative : bool, optional
+            Whether to return the negative sample during the fit stage
+            (can be set to False when using a loss function like NT-Xent),
+            by default True
+        cache_pool_bytes : int, optional
+            Size of the tensorstore cache pool in bytes, by default 0
+        """
         self.positions = positions
         self.channel_names = channel_names
         self.channel_indices = [
             positions[0].get_channel_index(ch) for ch in channel_names
         ]
         self.z_range = z_range
-        self.anchor_transform = anchor_transform
-        self.positive_transform = positive_transform
-        self.negative_transform = negative_transform
         self.fit = fit
         self.yx_patch_size = initial_yx_patch_size
         self.predict_cells = predict_cells
@@ -149,14 +178,44 @@ def __init__(
         )
         self.valid_anchors = self._filter_anchors(self.tracks)
         self.return_negative = return_negative
+        self._setup_tensorstore_context(cache_pool_bytes)
+
+    def _setup_tensorstore_context(self, cache_pool_bytes: int):
+        """Configure tensorstore context with CPU limits based on SLURM environment."""
+        cpus_per_task = os.environ.get("SLURM_CPUS_PER_TASK")
+        if cpus_per_task is not None:
+            cpus_per_task = int(cpus_per_task)
+        else:
+            cpus_per_task = os.cpu_count() or 4
+        self.tensorstore_context = ts.Context(
+            {
+                "data_copy_concurrency": {"limit": cpus_per_task},
+                "cache_pool": {"total_bytes_limit": cache_pool_bytes},
+            }
+        )
+        self._tensorstores = {}
+
+    def _get_tensorstore(self, position: Position) -> ts.TensorStore:
+        """Get cached tensorstore object or create and cache new one."""
+        fov_name = position.zgroup.name
+        if fov_name not in self._tensorstores:
+            self._tensorstores[fov_name] = position["0"].tensorstore(
+                context=self.tensorstore_context,
+                # assume immutable data to reduce metadata access
+                recheck_cached_data="open",
+            )
+        return self._tensorstores[fov_name]
 
     def _filter_tracks(self, tracks_tables: list[pd.DataFrame]) -> pd.DataFrame:
-        """Exclude tracks that are too close to the border or do not have the next time point.
+        """Exclude tracks that are too close to the border
+        or do not have the next time point.
 
Parameters ---------- tracks_tables : list[pd.DataFrame] - List of tracks_tables returned by TripletDataModule._align_tracks_tables_with_positions + List of tracks_tables returned by + TripletDataModule._align_tracks_tables_with_positions Returns ------- @@ -252,16 +311,19 @@ def _sample_negative(self, anchor_row: pd.Series) -> pd.Series: return candidates.sample(n=1).iloc[0] def _sample_negatives(self, anchor_rows: pd.DataFrame) -> pd.DataFrame: - return pd.concat( - [self._sample_negative(row) for _, row in anchor_rows.iterrows()], - axis=1, - ) + negative_samples = [ + self._sample_negative(row) for _, row in anchor_rows.iterrows() + ] + return pd.DataFrame(negative_samples).reset_index(drop=True) def _slice_patch( self, track_row: pd.Series ) -> tuple[ts.TensorStore, NormMeta | None]: position: Position = track_row["position"] - image = position["0"].tensorstore() + + # Get cached tensorstore object using FOV name + image = self._get_tensorstore(position) + time = track_row["t"] y_center = track_row["y"] x_center = track_row["x"] @@ -278,18 +340,18 @@ def _slice_patch( def _slice_patches(self, track_rows: pd.DataFrame): patches = [] norms = [] - with ts.Batch() as batch: - for _, row in track_rows.iterrows(): - patch, norm = self._slice_patch(row) - patches.append(patch.read(batch=batch)) - norms.append(norm) - results = [p.result() for p in patches] - return torch.from_numpy(np.stack(results, axis=0)), norms - - def __getitems__(self, indices: list[int]) -> list[TripletSample]: + for _, row in track_rows.iterrows(): + patch, norm = self._slice_patch(row) + patches.append(patch) + norms.append(norm) + results = ts.stack([p.translate_to[0] for p in patches]).read().result() + return torch.from_numpy(results), norms + + def __getitems__(self, indices: list[int]) -> dict[str, torch.Tensor]: """Get batched triplet samples for efficient data loading.""" anchor_rows = self.valid_anchors.iloc[indices] anchor_patches, anchor_norms = self._slice_patches(anchor_rows) + sample = {"anchor": anchor_patches, "anchor_norm_meta": anchor_norms} if self.fit: if self.time_interval == "any": positive_patches = anchor_patches.clone() @@ -297,51 +359,27 @@ def __getitems__(self, indices: list[int]) -> list[TripletSample]: else: positive_rows = self._sample_positives(anchor_rows) positive_patches, positive_norms = self._slice_patches(positive_rows) - if self.positive_transform: - positive_patches = _transform_channel_wise( - transform=self.positive_transform, - channel_names=self.channel_names, - patch=positive_patches, - norm_meta=positive_norms, - ) + + sample["positive"] = positive_patches + sample["positive_norm_meta"] = positive_norms if self.return_negative: negative_rows = self._sample_negatives(anchor_rows) negative_patches, negative_norms = self._slice_patches(negative_rows) - if self.negative_transform: - negative_patches = _transform_channel_wise( - transform=self.negative_transform, - channel_names=self.channel_names, - patch=negative_patches, - norm_meta=negative_norms, - ) - if self.anchor_transform: - anchor_patches = _transform_channel_wise( - transform=self.anchor_transform, - channel_names=self.channel_names, - patch=anchor_patches, - norm_meta=anchor_norms, - ) - samples: list[TripletSample] = [ - {"anchor": anchor_patch} for anchor_patch in anchor_patches - ] - if self.fit: - for sample, positive_patch in zip(samples, positive_patches): - sample["positive"] = positive_patch - if self.return_negative: - for sample, negative_patch in zip(samples, negative_patches): - sample["negative"] = 
negative_patch + sample["negative"] = negative_patches + sample["negative_norm_meta"] = negative_norms else: - for sample, (_, anchor_row) in zip(samples, anchor_rows.iterrows()): - # For new predictions, ensure all INDEX_COLUMNS are included + indices_list = [] + for _, anchor_row in anchor_rows.iterrows(): index_dict = {} for col in INDEX_COLUMNS: if col in anchor_row.index: index_dict[col] = anchor_row[col] elif col not in ["y", "x", "z"]: - # Skip y and x for legacy data - they weren't part of INDEX_COLUMNS raise KeyError(f"Required column '{col}' not found in data") - sample["index"] = index_dict - return samples + indices_list.append(index_dict) + sample["index"] = indices_list + + return sample class TripletDataModule(HCSDataModule): @@ -415,7 +453,7 @@ def __init__( final_yx_patch_size: tuple[int, int] = (224, 224), split_ratio: float = 0.8, batch_size: int = 16, - num_workers: int = 8, + num_workers: int = 1, normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], caching: bool = False, @@ -430,7 +468,71 @@ def __init__( prefetch_factor: int | None = None, pin_memory: bool = False, z_window_size: int | None = None, + cache_pool_bytes: int = 0, ): + """Lightning data module for triplet sampling of patches. + + Parameters + ---------- + data_path : str + Image dataset path + tracks_path : str + Tracks labels dataset path + source_channel : str | Sequence[str] + List of input channel names + z_range : tuple[int, int] + Range of valid z-slices + initial_yx_patch_size : tuple[int, int], optional + XY size of the initially sampled image patch, by default (512, 512) + final_yx_patch_size : tuple[int, int], optional + Output patch size, by default (224, 224) + split_ratio : float, optional + Ratio of training samples, by default 0.8 + batch_size : int, optional + Batch size, by default 16 + num_workers : int, optional + Number of thread workers. + Set to 0 to disable threading. Using more than 1 is not recommended. 
+            By default, 1.
+        normalizations : list[MapTransform], optional
+            Normalization transforms, by default []
+        augmentations : list[MapTransform], optional
+            Augmentation transforms, by default []
+        caching : bool, optional
+            Whether to cache the dataset, by default False
+        fit_include_wells : list[str], optional
+            Only include these wells for fitting, by default None
+        fit_exclude_fovs : list[str], optional
+            Exclude these FOVs for fitting, by default None
+        predict_cells : bool, optional
+            Only predict for selected cells, by default False
+        include_fov_names : list[str] | None, optional
+            Only predict for selected FOVs, by default None
+        include_track_ids : list[int] | None, optional
+            Only predict for selected tracks, by default None
+        time_interval : Literal["any"] | int, optional
+            Future time interval to sample positive and anchor from;
+            "any" means sampling the negative from another track at any
+            time point and using the augmented anchor patch as the positive,
+            by default "any"
+        return_negative : bool, optional
+            Whether to return the negative sample during the fit stage
+            (can be set to False when using a loss function like NT-Xent),
+            by default True
+        persistent_workers : bool, optional
+            Whether to keep worker processes alive between iterations, by default False
+        prefetch_factor : int | None, optional
+            Number of batches loaded in advance by each worker, by default None
+        pin_memory : bool, optional
+            Whether to pin memory in CPU for faster GPU transfer, by default False
+        z_window_size : int, optional
+            Size of the final Z window, by default None (inferred from z_range)
+        cache_pool_bytes : int, optional
+            Size of the per-process tensorstore cache pool in bytes, by default 0
+        """
+        if num_workers > 1:
+            warnings.warn(
+                "Using more than 1 thread worker will likely degrade performance."
+ ) super().__init__( data_path=data_path, source_channel=source_channel, @@ -458,6 +560,13 @@ def __init__( self.include_track_ids = include_track_ids self.time_interval = time_interval self.return_negative = return_negative + self._cache_pool_bytes = cache_pool_bytes + self._augmentation_transform = Compose( + self.normalizations + self.augmentations + [self._final_crop()] + ) + self._no_augmentation_transform = Compose( + self.normalizations + [self._final_crop()] + ) def _align_tracks_tables_with_positions( self, @@ -491,19 +600,10 @@ def _base_dataset_settings(self) -> dict: "channel_names": self.source_channel, "z_range": self.z_range, "time_interval": self.time_interval, + "cache_pool_bytes": self._cache_pool_bytes, } - def _update_to_device_transform(self): - """Make sure that GPU transforms are set to the current device.""" - for transform in self.normalizations + self.augmentations: - if isinstance(transform, ToDeviced): - transform.converter.device = torch.device( - f"cuda:{torch.cuda.current_device()}" - ) - def _setup_fit(self, dataset_settings: dict): - self._update_to_device_transform() - augment_transform, no_aug_transform = self._fit_transform() positions, tracks_tables = self._align_tracks_tables_with_positions() shuffled_indices = self._set_fit_global_state(len(positions)) positions = [positions[i] for i in shuffled_indices] @@ -516,30 +616,18 @@ def _setup_fit(self, dataset_settings: dict): val_tracks_tables = tracks_tables[num_train_fovs:] _logger.debug(f"Number of training FOVs: {len(train_positions)}") _logger.debug(f"Number of validation FOVs: {len(val_positions)}") - anchor_transform = ( - no_aug_transform - if (self.time_interval == "any" or self.time_interval == 0) - else augment_transform - ) self.train_dataset = TripletDataset( positions=train_positions, tracks_tables=train_tracks_tables, initial_yx_patch_size=self.initial_yx_patch_size, - anchor_transform=anchor_transform, - positive_transform=augment_transform, - negative_transform=augment_transform, fit=True, return_negative=self.return_negative, **dataset_settings, ) - self.val_dataset = TripletDataset( positions=val_positions, tracks_tables=val_tracks_tables, initial_yx_patch_size=self.initial_yx_patch_size, - anchor_transform=anchor_transform, - positive_transform=augment_transform, - negative_transform=augment_transform, fit=True, return_negative=self.return_negative, **dataset_settings, @@ -552,7 +640,6 @@ def _setup_predict(self, dataset_settings: dict): positions=positions, tracks_tables=tracks_tables, initial_yx_patch_size=self.initial_yx_patch_size, - anchor_transform=Compose(self.normalizations), fit=False, predict_cells=self.predict_cells, include_fov_names=self.include_fov_names, @@ -581,6 +668,7 @@ def train_dataloader(self) -> ThreadDataLoader: persistent_workers=self.persistent_workers, drop_last=True, pin_memory=self.pin_memory, + collate_fn=lambda x: x, ) def val_dataloader(self) -> ThreadDataLoader: @@ -601,6 +689,7 @@ def val_dataloader(self) -> ThreadDataLoader: persistent_workers=self.persistent_workers, drop_last=False, pin_memory=self.pin_memory, + collate_fn=lambda x: x, ) def predict_dataloader(self) -> ThreadDataLoader: @@ -621,4 +710,46 @@ def predict_dataloader(self) -> ThreadDataLoader: persistent_workers=self.persistent_workers, drop_last=False, pin_memory=self.pin_memory, + collate_fn=lambda x: x, ) + + def _final_crop(self) -> BatchedCenterSpatialCropd: + """Setup final cropping: center crop to the target size.""" + return BatchedCenterSpatialCropd( + keys=self.source_channel, 
+ roi_size=( + self.z_window_size, + self.yx_patch_size[0], + self.yx_patch_size[1], + ), + ) + + def _find_transform(self, key: str): + if self.trainer: + if self.trainer.predicting: + return self._no_augmentation_transform + # NOTE: for backwards compatibility + if key == "anchor" and self.time_interval in ("any", 0): + return self._no_augmentation_transform + return self._augmentation_transform + + def on_after_batch_transfer(self, batch, dataloader_idx: int): + """Apply transforms after transferring to device.""" + if isinstance(batch, Tensor): + # example array + return batch + for key in ["anchor", "positive", "negative"]: + if key in batch: + norm_meta_key = f"{key}_norm_meta" + norm_meta = batch.get(norm_meta_key) + transformed_patches = _transform_channel_wise( + transform=self._find_transform(key), + channel_names=self.source_channel, + patch=batch[key], + norm_meta=norm_meta, + ) + batch[key] = transformed_patches + if norm_meta_key in batch: + del batch[norm_meta_key] + + return batch diff --git a/viscy/representation/classification.py b/viscy/representation/classification.py index cb686f3b1..718866124 100644 --- a/viscy/representation/classification.py +++ b/viscy/representation/classification.py @@ -74,6 +74,8 @@ class ClassificationModule(LightningModule): Learning rate. loss : nn.Module | None Loss function. By default, BCEWithLogitsLoss with positive weight of 1.0. + example_input_array_shape : tuple[int, ...] + Shape of the example input array. """ def __init__( @@ -81,6 +83,7 @@ def __init__( encoder: ContrastiveEncoder, lr: float | None, loss: nn.Module | None = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(1.0)), + example_input_array_shape: tuple[int, ...] = (2, 1, 15, 160, 160), ) -> None: super().__init__() self.stem = encoder.stem @@ -88,7 +91,7 @@ def __init__( self.backbone.head.fc = nn.Linear(768, 1) self.loss = loss self.lr = lr - self.example_input_array = torch.rand(2, 1, 15, 160, 160) + self.example_input_array = torch.rand(example_input_array_shape) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through stem and backbone for classification. 
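+
+        A hedged shape sketch, assuming the default
+        ``example_input_array_shape`` of ``(2, 1, 15, 160, 160)`` and the
+        single-logit head set in ``__init__`` (``module`` is a hypothetical
+        constructed instance)::
+
+            >>> logits = module(torch.rand(2, 1, 15, 160, 160))
+            >>> logits.shape
+            torch.Size([2, 1])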
diff --git a/viscy/scripts/bench_augmentations.py b/viscy/scripts/bench_augmentations.py index a5e358cf5..620314a15 100644 --- a/viscy/scripts/bench_augmentations.py +++ b/viscy/scripts/bench_augmentations.py @@ -1,64 +1,62 @@ # %% import torch -from kornia.augmentation import RandomAffine3D from lightning.pytorch import seed_everything from monai.data.meta_obj import set_track_meta -from monai.transforms import RandAffine +from monai.transforms import RandSpatialCrop from torch.utils.benchmark import Timer +from viscy.transforms import BatchedRandSpatialCrop + seed_everything(42) # %% x = torch.rand(32, 2, 15, 512, 512, device="cuda") # %% -monai_transform = RandAffine( - prob=1.0, - rotate_range=(torch.pi, 0, 0), - scale_range=(0.2, 0.3, 0.3), - padding_mode="zeros", -) +roi_size = [8, 256, 256] -kornia_transform = RandomAffine3D( - degrees=(360.0, 0.0, 0.0), - scale=((0.8, 1.2), (0.7, 1.3), (0.7, 1.3)), - p=1.0, +monai_transform = RandSpatialCrop( + roi_size=roi_size, random_center=True, random_size=False ) +batched_transform = BatchedRandSpatialCrop(roi_size=roi_size, random_center=True) # %% def bench_monai(x): set_track_meta(False) with torch.inference_mode(): + results = [] for sample in x: - _ = monai_transform(sample) + cropped = monai_transform(sample) + results.append(cropped) + return torch.stack(results) -def bench_kornia(x): +def bench_batched(x): with torch.inference_mode(): - _ = kornia_transform(x) + return batched_transform(x) # %% globals_injection = { "x": x, - "monai_transform": monai_transform, - "kornia_transform": kornia_transform, + "bench_monai": bench_monai, + "bench_batched": bench_batched, } monai_timer = Timer( stmt="bench_monai(x)", globals=globals_injection, - label="monai", + label="MONAI (loop)", setup="from __main__ import bench_monai", # num_threads=16, ) -kornia_timer = Timer( - stmt="bench_kornia(x)", +batched_timer = Timer( + stmt="bench_batched(x)", globals=globals_injection, - label="kornia", - setup="from __main__ import bench_kornia", + label="Batched (gather)", + setup="from __main__ import bench_batched", # num_threads=16, ) @@ -66,6 +64,6 @@ def bench_kornia(x): monai_timer.timeit(10) # %% -kornia_timer.timeit(10) +batched_timer.timeit(10) # %% diff --git a/viscy/scripts/profiling.py b/viscy/scripts/profiling.py index 5b978a09f..4cc3e6c6e 100644 --- a/viscy/scripts/profiling.py +++ b/viscy/scripts/profiling.py @@ -1,55 +1,45 @@ # script to profile dataloading # use with a sampling profiler like py-spy -from monai.transforms import ( - Decollated, - RandAdjustContrastd, - RandGaussianSmoothd, - RandScaleIntensityd, - ToDeviced, -) -from pytorch_metric_learning.losses import NTXentLoss + + +import logging + +import torch +from lightning.pytorch import LightningModule, Trainer from viscy.data.combined import BatchedConcatDataModule from viscy.data.triplet import TripletDataModule -from viscy.representation.engine import ContrastiveEncoder, ContrastiveModule from viscy.transforms import ( - NormalizeSampled, -) -from viscy.transforms._transforms import ( + BatchedCenterSpatialCropd, + BatchedRandAdjustContrastd, BatchedRandAffined, + BatchedRandGaussianNoised, + BatchedRandGaussianSmoothd, + BatchedRandScaleIntensityd, BatchedScaleIntensityRangePercentilesd, - RandGaussianNoiseTensord, + NormalizeSampled, ) +_logger = logging.getLogger(__name__) -def model( - input_channel_number: int = 1, - z_stack_depth: int = 30, - patch_size: int = 192, - temperature: float = 0.5, -): - return ContrastiveModule( - encoder=ContrastiveEncoder( - 
backbone="convnext_tiny", - in_channels=input_channel_number, - in_stack_depth=z_stack_depth, - stem_kernel_size=(5, 4, 4), - embedding_dim=768, - projection_dim=32, - drop_path_rate=0.0, - ), - loss_function=NTXentLoss(temperature=temperature), - lr=0.00002, - log_batches_per_epoch=3, - log_samples_per_batch=3, - example_input_array_shape=[ - 1, - input_channel_number, - z_stack_depth, - patch_size, - patch_size, - ], - ) + +class DummyModel(LightningModule): + def __init__(self): + super().__init__() + self.a = torch.nn.Parameter(torch.zeros(1, requires_grad=True)) + + def training_step(self, batch, batch_idx): + img = batch["anchor"] + _logger.info(img.shape) + return (img * self.a).mean() + + def validation_step(self, batch, batch_idx): + img = batch["anchor"] + _logger.info(img.shape) + return (img * self.a).mean() + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters()) def channel_augmentations(processing_channel: str): @@ -61,25 +51,25 @@ def channel_augmentations(processing_channel: str): rotate_range=[1.0, 0.0, 0.0], shear_range=(0.2, 0.2, 0.0, 0.2, 0.0, 0.2), ), - Decollated(keys=[processing_channel]), - RandAdjustContrastd( + BatchedCenterSpatialCropd(keys=[processing_channel], roi_size=(32, 192, 192)), + BatchedRandAdjustContrastd( keys=[processing_channel], prob=0.5, - gamma=[0.8, 1.2], + gamma=(0.8, 1.2), ), - RandScaleIntensityd( + BatchedRandScaleIntensityd( keys=[processing_channel], prob=0.5, factors=0.5, ), - RandGaussianSmoothd( + BatchedRandGaussianSmoothd( keys=[processing_channel], prob=0.5, - sigma_x=[0.25, 0.75], - sigma_y=[0.25, 0.75], - sigma_z=[0.0, 0.0], + sigma_x=(0.25, 0.75), + sigma_y=(0.25, 0.75), + sigma_z=(0.0, 0.0), ), - RandGaussianNoiseTensord( + BatchedRandGaussianNoised( keys=[processing_channel], prob=0.5, mean=0.0, @@ -103,7 +93,6 @@ def channel_normalization( ] elif fl_channel: return [ - ToDeviced(keys=[fl_channel], device="cuda"), BatchedScaleIntensityRangePercentilesd( keys=[fl_channel], lower=50, @@ -117,15 +106,19 @@ def channel_normalization( if __name__ == "__main__": + num_workers = 1 + batch_size = 128 + persistent_workers = True + cache_pool_bytes = 32 << 30 dm1 = TripletDataModule( data_path="/hpc/projects/organelle_phenotyping/datasets/organelle/SEC61B/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV_2.zarr", - tracks_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_10_16_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/3-track/2024_10_16_A549_SEC61_ZIKV_DENV_cropped.zarr", + tracks_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_10_16_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/3-track/2024_10_16_A549_SEC61_ZIKV_DENV_cropped.zarr", source_channel=["raw GFP EX488 EM525-45"], z_range=[5, 35], initial_yx_patch_size=(384, 384), final_yx_patch_size=(192, 192), - batch_size=16, - num_workers=4, + batch_size=batch_size, + num_workers=num_workers, time_interval=1, augmentations=channel_augmentations("raw GFP EX488 EM525-45"), normalizations=channel_normalization( @@ -133,37 +126,27 @@ def channel_normalization( ), fit_include_wells=["B/3", "B/4", "C/3", "C/4"], return_negative=False, + persistent_workers=persistent_workers, + cache_pool_bytes=cache_pool_bytes, ) dm2 = TripletDataModule( data_path="/hpc/projects/organelle_phenotyping/datasets/organelle/SEC61B/2024_10_16_A549_SEC61_ZIKV_DENV/2024_10_16_A549_SEC61_ZIKV_DENV_2.zarr", - 
tracks_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_10_16_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/3-track/2024_10_16_A549_SEC61_ZIKV_DENV_cropped.zarr", - source_channel=["raw mCherry EX561 EM600-37"], + tracks_path="/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_10_16_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/3-track/2024_10_16_A549_SEC61_ZIKV_DENV_cropped.zarr", + source_channel=["Phase3D"], z_range=[5, 35], initial_yx_patch_size=(384, 384), final_yx_patch_size=(192, 192), - batch_size=16, - num_workers=4, + batch_size=batch_size, + num_workers=num_workers, time_interval=1, - augmentations=channel_augmentations("raw mCherry EX561 EM600-37"), - normalizations=channel_normalization( - phase_channel=None, fl_channel="raw mCherry EX561 EM600-37" - ), + augmentations=channel_augmentations("Phase3D"), + normalizations=channel_normalization(phase_channel="Phase3D", fl_channel=None), fit_include_wells=["B/3", "B/4", "C/3", "C/4"], return_negative=False, + persistent_workers=persistent_workers, + cache_pool_bytes=cache_pool_bytes, ) dm = BatchedConcatDataModule(data_modules=[dm1, dm2]) - dm.setup("fit") - - print(len(dm1.train_dataset), len(dm2.train_dataset), len(dm.train_dataset)) - n = 1 - - print("Training batches:") - for i, batch in enumerate(dm.train_dataloader()): - print(i, batch["anchor"].shape, batch["positive"].device) - if i == n - 1: - break - print("Validation batches:") - for i, batch in enumerate(dm.val_dataloader()): - print(i, batch["anchor"].shape, batch["positive"].device) - if i == n - 1: - break + model = DummyModel() + trainer = Trainer(max_epochs=4, limit_train_batches=8, limit_val_batches=8) + trainer.fit(model, dm) diff --git a/viscy/transforms/__init__.py b/viscy/transforms/__init__.py index 6ca88eb4b..cccb749de 100644 --- a/viscy/transforms/__init__.py +++ b/viscy/transforms/__init__.py @@ -1,5 +1,27 @@ """VisCy transform package for data preprocessing and augmentation.""" +from viscy.transforms._adjust_contrast import ( + BatchedRandAdjustContrast, + BatchedRandAdjustContrastd, +) +from viscy.transforms._crop import ( + BatchedCenterSpatialCrop, + BatchedCenterSpatialCropd, + BatchedRandSpatialCrop, + BatchedRandSpatialCropd, +) +from viscy.transforms._decollate import Decollate +from viscy.transforms._flip import BatchedRandFlip, BatchedRandFlipd +from viscy.transforms._gaussian_smooth import ( + BatchedRandGaussianSmooth, + BatchedRandGaussianSmoothd, +) +from viscy.transforms._noise import ( + BatchedRandGaussianNoise, + BatchedRandGaussianNoised, + RandGaussianNoiseTensor, + RandGaussianNoiseTensord, +) from viscy.transforms._redef import ( CenterSpatialCropd, Decollated, @@ -14,28 +36,61 @@ ScaleIntensityRangePercentilesd, ToDeviced, ) +from viscy.transforms._scale_intensity import ( + BatchedRandScaleIntensity, + BatchedRandScaleIntensityd, +) from viscy.transforms._transforms import ( BatchedRandAffined, + BatchedScaleIntensityRangePercentiles, BatchedScaleIntensityRangePercentilesd, BatchedZoom, NormalizeSampled, - RandGaussianNoiseTensord, RandInvertIntensityd, StackChannelsd, TiledSpatialCropSamplesd, ) +from viscy.transforms.batched_rand_3d_elasticd import BatchedRand3DElasticd +from viscy.transforms.batched_rand_histogram_shiftd import BatchedRandHistogramShiftd +from viscy.transforms.batched_rand_local_pixel_shufflingd import ( + BatchedRandLocalPixelShufflingd, +) +from viscy.transforms.batched_rand_sharpend import BatchedRandSharpend +from viscy.transforms.batched_rand_zstack_shiftd import 
BatchedRandZStackShiftd __all__ = [ + "BatchedCenterSpatialCrop", + "BatchedCenterSpatialCropd", + "BatchedRandAdjustContrast", + "BatchedRandAdjustContrastd", "BatchedRandAffined", + "BatchedRand3DElasticd", + "BatchedRandFlip", + "BatchedRandFlipd", + "BatchedRandGaussianSmooth", + "BatchedRandGaussianSmoothd", + "BatchedRandGaussianNoise", + "BatchedRandGaussianNoised", + "BatchedRandHistogramShiftd", + "BatchedRandLocalPixelShufflingd", + "BatchedRandScaleIntensity", + "BatchedRandScaleIntensityd", + "BatchedRandSharpend", + "BatchedRandSpatialCrop", + "BatchedRandSpatialCropd", + "BatchedRandZStackShiftd", + "BatchedScaleIntensityRangePercentiles", "BatchedScaleIntensityRangePercentilesd", "BatchedZoom", "CenterSpatialCropd", + "Decollate", "Decollated", "NormalizeSampled", "RandAdjustContrastd", "RandAffined", "RandFlipd", "RandGaussianNoised", + "RandGaussianNoiseTensor", "RandGaussianNoiseTensord", "RandGaussianSmoothd", "RandInvertIntensityd", diff --git a/viscy/transforms/_transforms.py b/viscy/transforms/_transforms.py index 50ad7a812..0a32e1edd 100644 --- a/viscy/transforms/_transforms.py +++ b/viscy/transforms/_transforms.py @@ -8,8 +8,6 @@ from monai.transforms import ( MapTransform, MultiSampleTrait, - RandGaussianNoise, - RandGaussianNoised, RandomizableTransform, ScaleIntensityRangePercentiles, Transform, @@ -421,58 +419,3 @@ def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]: d[key] = self.random_affine(data) assert d[key].device == data.device return d - - -class RandGaussianNoiseTensor(RandGaussianNoise): - """Rand Gaussian Noise Tensor.""" - - def randomize(self, img: Tensor, mean: float | None = None) -> None: - self._do_transform = self.R.rand() < self.prob - if not self._do_transform: - return None - std = self.R.uniform(0, self.std) if self.sample_std else self.std - self.noise = torch.normal( - self.mean if mean is None else mean, - std, - size=img.shape, - device=img.device, - dtype=img.dtype, - ) - - -class RandGaussianNoiseTensord(RandGaussianNoised): - """Rand Gaussian Noise Tensor. - - Parameters - ---------- - keys : str | Iterable[str] - Keys to noise. - prob : float, optional - Probability of noise. By default, 0.1. - mean : float, optional - Mean. By default, 0.0. - std : float, optional - Standard deviation. By default, 0.1. - dtype : DTypeLike, optional - Data type. By default, np.float32. - allow_missing_keys : bool, optional - Whether to allow missing keys. By default, False. - sample_std : bool, optional - Whether to sample the standard deviation. By default, True. 
- """ - - def __init__( - self, - keys: str | Iterable[str], - prob: float = 0.1, - mean: float = 0.0, - std: float = 0.1, - dtype: DTypeLike = np.float32, - allow_missing_keys: bool = False, - sample_std: bool = True, - ) -> None: - MapTransform.__init__(self, keys, allow_missing_keys) - RandomizableTransform.__init__(self, prob) - self.rand_gaussian_noise = RandGaussianNoiseTensor( - mean=mean, std=std, prob=1.0, dtype=dtype, sample_std=sample_std - ) From 58aef5471b6a053bf716de3ca1f7e716c3fa4be5 Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Wed, 17 Sep 2025 18:05:40 -0700 Subject: [PATCH 12/13] updated some batched transforms docs --- tests/transforms/test_adjust_contrast.py | 1 - tests/transforms/test_crop.py | 1 - tests/transforms/test_flip.py | 1 - tests/transforms/test_gaussian_smooth.py | 1 - tests/transforms/test_noise.py | 1 - tests/transforms/test_scale_intensity.py | 1 - tests/transforms/test_transforms.py | 1 - viscy/data/triplet.py | 3 +- viscy/transforms/batched_rand_3d_elasticd.py | 36 ++++++++++++++++++- .../batched_rand_histogram_shiftd.py | 28 ++++++++++++++- .../batched_rand_local_pixel_shufflingd.py | 30 +++++++++++++++- viscy/transforms/batched_rand_sharpend.py | 28 ++++++++++++++- .../transforms/batched_rand_zstack_shiftd.py | 32 ++++++++++++++++- viscy/utils/blend.py | 16 +++++++++ 14 files changed, 166 insertions(+), 14 deletions(-) diff --git a/tests/transforms/test_adjust_contrast.py b/tests/transforms/test_adjust_contrast.py index d2cd6e9cc..a40538162 100644 --- a/tests/transforms/test_adjust_contrast.py +++ b/tests/transforms/test_adjust_contrast.py @@ -1,7 +1,6 @@ import pytest import torch from monai.transforms import AdjustContrast, Compose - from viscy.transforms import BatchedRandAdjustContrast, BatchedRandAdjustContrastd diff --git a/tests/transforms/test_crop.py b/tests/transforms/test_crop.py index e5f80ae04..ca26b9efe 100644 --- a/tests/transforms/test_crop.py +++ b/tests/transforms/test_crop.py @@ -1,7 +1,6 @@ import pytest import torch from monai.transforms import Compose - from viscy.transforms._crop import ( BatchedCenterSpatialCrop, BatchedCenterSpatialCropd, diff --git a/tests/transforms/test_flip.py b/tests/transforms/test_flip.py index 0fbd1bf5a..eb4596054 100644 --- a/tests/transforms/test_flip.py +++ b/tests/transforms/test_flip.py @@ -1,6 +1,5 @@ import pytest import torch - from viscy.transforms import BatchedRandFlip, BatchedRandFlipd diff --git a/tests/transforms/test_gaussian_smooth.py b/tests/transforms/test_gaussian_smooth.py index 64bd979aa..136c1c55d 100644 --- a/tests/transforms/test_gaussian_smooth.py +++ b/tests/transforms/test_gaussian_smooth.py @@ -7,7 +7,6 @@ get_gaussian_kernel3d, ) from monai.transforms.intensity.array import GaussianSmooth - from viscy.transforms import BatchedRandGaussianSmooth, BatchedRandGaussianSmoothd from viscy.transforms._gaussian_smooth import filter3d_separable diff --git a/tests/transforms/test_noise.py b/tests/transforms/test_noise.py index da9a1e9fd..5e58e5c9b 100644 --- a/tests/transforms/test_noise.py +++ b/tests/transforms/test_noise.py @@ -1,7 +1,6 @@ import pytest import torch from monai.transforms import Compose - from viscy.transforms import BatchedRandGaussianNoise, BatchedRandGaussianNoised diff --git a/tests/transforms/test_scale_intensity.py b/tests/transforms/test_scale_intensity.py index 2cfdaa954..03d323074 100644 --- a/tests/transforms/test_scale_intensity.py +++ b/tests/transforms/test_scale_intensity.py @@ -1,7 +1,6 @@ import pytest import torch from monai.transforms 
import RandScaleIntensity
-
 from viscy.transforms import BatchedRandScaleIntensity, BatchedRandScaleIntensityd
diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py
index 88b955e3b..8e3efb481 100644
--- a/tests/transforms/test_transforms.py
+++ b/tests/transforms/test_transforms.py
@@ -1,6 +1,5 @@
 import pytest
 import torch
-
 from viscy.transforms._decollate import Decollate
 from viscy.transforms._transforms import (
     BatchedScaleIntensityRangePercentiles,
diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py
index bf9ad7ee4..8110e6096 100644
--- a/viscy/data/triplet.py
+++ b/viscy/data/triplet.py
@@ -207,8 +207,7 @@ def _get_tensorstore(self, position: Position) -> ts.TensorStore:
         return self._tensorstores[fov_name]
 
     def _filter_tracks(self, tracks_tables: list[pd.DataFrame]) -> pd.DataFrame:
-        """Exclude tracks that are too close to the border
-        or do not have the next time point.
+        """Exclude tracks that are too close to the border or do not have the next time point.
 
         Parameters
         ----------
diff --git a/viscy/transforms/batched_rand_3d_elasticd.py b/viscy/transforms/batched_rand_3d_elasticd.py
index 7e38f01d1..c28ebcbea 100644
--- a/viscy/transforms/batched_rand_3d_elasticd.py
+++ b/viscy/transforms/batched_rand_3d_elasticd.py
@@ -5,7 +5,29 @@
 
 class BatchedRand3DElasticd(MapTransform, RandomizableTransform):
-    """Batched 3D elastic deformation for biological structures."""
+    """Apply random 3D elastic deformation to image data.
+
+    Uses Gaussian-smoothed displacement fields to simulate natural tissue deformation.
+
+    Parameters
+    ----------
+    keys : str or Iterable[str]
+        Keys of the corresponding items to be transformed.
+    sigma_range : tuple[float, float]
+        Range for random sigma values used in Gaussian smoothing.
+    magnitude_range : tuple[float, float]
+        Range for random displacement magnitude values.
+    spatial_size : tuple[int, int, int] or int or None, optional
+        Expected spatial size of input data.
+    prob : float, optional
+        Probability of applying the transform, by default 0.1.
+    mode : str, optional
+        Interpolation mode for grid sampling, by default "bilinear".
+    padding_mode : str, optional
+        Padding mode for grid sampling, by default "reflection".
+    allow_missing_keys : bool, optional
+        Whether to ignore missing keys, by default False.
+    """
 
     def __init__(
         self,
@@ -76,6 +98,18 @@ def _generate_elastic_field(
         return torch.stack(displacement_fields)
 
     def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Apply elastic deformation to sample data.
+
+        Parameters
+        ----------
+        sample : dict[str, Tensor]
+            Dictionary containing image tensors to transform.
+
+        Returns
+        -------
+        dict[str, Tensor]
+            Dictionary with transformed tensors.
+        """
         self.randomize(None)
 
         d = dict(sample)
diff --git a/viscy/transforms/batched_rand_histogram_shiftd.py b/viscy/transforms/batched_rand_histogram_shiftd.py
index e7fe2b39d..60880a4c3 100644
--- a/viscy/transforms/batched_rand_histogram_shiftd.py
+++ b/viscy/transforms/batched_rand_histogram_shiftd.py
@@ -5,7 +5,21 @@
 
 class BatchedRandHistogramShiftd(MapTransform, RandomizableTransform):
-    """Batched random histogram shifting for intensity distribution changes."""
+    """
+
+    Apply random histogram shifts to modify intensity distributions.
+
+    Parameters
+    ----------
+    keys : str or Iterable[str]
+        Keys of the corresponding items to be transformed.
+    shift_range : tuple[float, float], optional
+        Range for random intensity shift values, by default (-0.1, 0.1).
+ prob : float, optional + Probability of applying the transform, by default 0.1. + allow_missing_keys : bool, optional + Whether to ignore missing keys, by default False. + """ def __init__( self, @@ -19,6 +33,18 @@ def __init__( self.shift_range = shift_range def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]: + """Apply histogram shift to sample data. + + Parameters + ---------- + sample : dict[str, Tensor] + Dictionary containing image tensors to transform. + + Returns + ------- + dict[str, Tensor] + Dictionary with intensity-shifted tensors. + """ self.randomize(None) d = dict(sample) diff --git a/viscy/transforms/batched_rand_local_pixel_shufflingd.py b/viscy/transforms/batched_rand_local_pixel_shufflingd.py index 73cd9caf4..8cebc41e1 100644 --- a/viscy/transforms/batched_rand_local_pixel_shufflingd.py +++ b/viscy/transforms/batched_rand_local_pixel_shufflingd.py @@ -5,7 +5,23 @@ class BatchedRandLocalPixelShufflingd(MapTransform, RandomizableTransform): - """Batched random local pixel shuffling for texture augmentation.""" + """Apply random local pixel shuffling to simulate texture variations. + + Shuffles pixels within small local patches to add texture noise. + + Parameters + ---------- + keys : str or Iterable[str] + Keys of the corresponding items to be transformed. + patch_size : int, optional + Size of local patches for pixel shuffling, by default 3. + shuffle_prob : float, optional + Probability of shuffling within patches, by default 0.1. + prob : float, optional + Probability of applying the transform, by default 0.1. + allow_missing_keys : bool, optional + Whether to ignore missing keys, by default False. + """ def __init__( self, @@ -72,6 +88,18 @@ def _shuffle_patches(self, data: Tensor) -> Tensor: return result def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]: + """Apply pixel shuffling to sample data. + + Parameters + ---------- + sample : dict[str, Tensor] + Dictionary containing image tensors to transform. + + Returns + ------- + dict[str, Tensor] + Dictionary with pixel-shuffled tensors. + """ self.randomize(None) d = dict(sample) diff --git a/viscy/transforms/batched_rand_sharpend.py b/viscy/transforms/batched_rand_sharpend.py index fe2a54c07..92b992b5f 100644 --- a/viscy/transforms/batched_rand_sharpend.py +++ b/viscy/transforms/batched_rand_sharpend.py @@ -6,7 +6,21 @@ class BatchedRandSharpend(MapTransform, RandomizableTransform): - """Batched random sharpening for microscopy images.""" + """Apply random sharpening to enhance image edges and details. + + Uses 3D convolution with sharpening kernel to enhance fine structures. + + Parameters + ---------- + keys : str or Iterable[str] + Keys of the corresponding items to be transformed. + alpha_range : tuple[float, float], optional + Range for random alpha blending values, by default (0.1, 0.5). + prob : float, optional + Probability of applying the transform, by default 0.1. + allow_missing_keys : bool, optional + Whether to ignore missing keys, by default False. + """ def __init__( self, @@ -40,6 +54,18 @@ def _get_sharpen_kernel(self, device: torch.device, channels: int) -> Tensor: return self._cached_kernel def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]: + """Apply sharpening to sample data. + + Parameters + ---------- + sample : dict[str, Tensor] + Dictionary containing image tensors to transform. + + Returns + ------- + dict[str, Tensor] + Dictionary with sharpened tensors. 
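+
+        Examples
+        --------
+        A minimal sketch; the key and shapes are illustrative only::
+
+            >>> transform = BatchedRandSharpend(keys=["img"], prob=1.0)
+            >>> out = transform({"img": torch.rand(2, 1, 8, 32, 32)})
+            >>> out["img"].shape
+            torch.Size([2, 1, 8, 32, 32])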
+ """ self.randomize(None) d = dict(sample) diff --git a/viscy/transforms/batched_rand_zstack_shiftd.py b/viscy/transforms/batched_rand_zstack_shiftd.py index e94e5b714..1fbad6638 100644 --- a/viscy/transforms/batched_rand_zstack_shiftd.py +++ b/viscy/transforms/batched_rand_zstack_shiftd.py @@ -5,7 +5,25 @@ class BatchedRandZStackShiftd(MapTransform, RandomizableTransform): - """Batched random Z-axis shifts for 3D microscopy data.""" + """Apply random shifts along Z-axis to simulate focal plane variations. + + Shifts image data in the depth dimension to augment focal plane diversity. + + Parameters + ---------- + keys : str or Iterable[str] + Keys of the corresponding items to be transformed. + max_shift : int, optional + Maximum shift distance in Z direction, by default 3. + prob : float, optional + Probability of applying the transform, by default 0.1. + mode : str, optional + Padding mode for shifted regions, by default "constant". + cval : float, optional + Fill value for constant padding, by default 0.0. + allow_missing_keys : bool, optional + Whether to ignore missing keys, by default False. + """ def __init__( self, @@ -23,6 +41,18 @@ def __init__( self.cval = cval def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]: + """Apply Z-axis shift to sample data. + + Parameters + ---------- + sample : dict[str, Tensor] + Dictionary containing image tensors to transform. + + Returns + ------- + dict[str, Tensor] + Dictionary with Z-shifted tensors. + """ self.randomize(None) d = dict(sample) diff --git a/viscy/utils/blend.py b/viscy/utils/blend.py index 18167f852..006ee59a3 100644 --- a/viscy/utils/blend.py +++ b/viscy/utils/blend.py @@ -6,6 +6,22 @@ def blend_channels( image: np.ndarray, cmaps: list[Colormap], rescale: bool ) -> np.ndarray: + """Blend multi-channel images using specified colormaps. + + Parameters + ---------- + image : np.ndarray + Multi-channel image array to blend. + cmaps : list[Colormap] + List of colormaps for each channel. + rescale : bool + Whether to rescale intensity values to [0, 1] range. + + Returns + ------- + np.ndarray + Blended RGB image clipped to [0, 1] range. + """ rendered_channels = [] for channel, cmap in zip(image, cmaps): colormap = Colormap(cmap) From 69ee64526b9b3867f59915fb090a99a6fd242b3d Mon Sep 17 00:00:00 2001 From: Sricharan Reddy Varra Date: Wed, 17 Sep 2025 18:06:18 -0700 Subject: [PATCH 13/13] ruff format --- viscy/transforms/batched_rand_histogram_shiftd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/transforms/batched_rand_histogram_shiftd.py b/viscy/transforms/batched_rand_histogram_shiftd.py index 60880a4c3..c3ee2d332 100644 --- a/viscy/transforms/batched_rand_histogram_shiftd.py +++ b/viscy/transforms/batched_rand_histogram_shiftd.py @@ -6,7 +6,7 @@ class BatchedRandHistogramShiftd(MapTransform, RandomizableTransform): """ - + Apply random histogram shifts to modify intensity distributions. Parameters