Skip to content

✨ Define SemanticSegmentor with the New EngineABC #866

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 189 commits into
base: dev-define-engines-abc
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 63 commits
Commits
Show all changes
189 commits
Select commit Hold shift + click to select a range
193c587
:sparkles: Define `SemanticSegmentor` with the New `EngineABC`
shaneahmed Sep 20, 2024
113fb6b
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Nov 20, 2024
44e4b8c
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Nov 21, 2024
7e8d78b
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Nov 22, 2024
1ea6c80
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Dec 3, 2024
d0b86a5
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Jan 3, 2025
399827a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
333264a
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Jan 24, 2025
c84099b
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Jan 24, 2025
16a632f
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Feb 5, 2025
e03db92
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Feb 21, 2025
13cc94a
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Mar 3, 2025
dd83117
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Mar 4, 2025
967dba1
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Mar 7, 2025
40e79a1
:hammer: Add `run` method to SemanticSegmentor
shaneahmed Mar 7, 2025
85922ff
:memo: Update docstring
shaneahmed Mar 7, 2025
bda0581
:recycle: Refactor resolution to input_resolutions.
shaneahmed Mar 7, 2025
a643ea6
:recycle: Use `input_resolutions` instead of resolution
shaneahmed Mar 7, 2025
7eed649
:recycle: Use `input_resolutions` instead of resolution
shaneahmed Mar 7, 2025
3dff881
:white_check_mark: Add test to cli.
shaneahmed Mar 7, 2025
52ed249
Merge branch 'dev-use-input-resolutions' into dev-define-semantic-seg…
shaneahmed Mar 7, 2025
03e07e6
:white_check_mark: Add SemanticSegmentor patch_mode test.
shaneahmed Mar 7, 2025
a37e71d
:bug: Fix `unet` architecture
shaneahmed Mar 8, 2025
94a747f
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Mar 8, 2025
987927d
Merge branch 'dev-define-engines-abc' into dev-use-input-resolutions
shaneahmed Mar 8, 2025
92bc813
:bug: Fix `test_datset` architecture
shaneahmed Mar 8, 2025
bddd956
Merge branch 'dev-use-input-resolutions' into dev-define-semantic-seg…
shaneahmed Mar 8, 2025
facf461
:white_check_mark: Add postproc to segmentation.
shaneahmed Mar 12, 2025
2b342f4
Merge remote-tracking branch 'origin/dev-define-semantic-segmentor' i…
shaneahmed Mar 12, 2025
6f9d412
:recycle: Move argmax postprocessing to utils.
shaneahmed Mar 12, 2025
91a45b3
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Mar 14, 2025
8a95948
:white_check_mark: Check for cache_mode with zarr output.
shaneahmed Mar 17, 2025
cf5b50e
:white_check_mark: Update script for annotation store
shaneahmed Mar 18, 2025
816a568
:construction: Update script for annotation store
shaneahmed Mar 18, 2025
2f88cb7
:construction: Update script for annotation store
shaneahmed Mar 19, 2025
91d134d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2025
9204604
:bug: Fix saving annotation properties
shaneahmed Mar 19, 2025
84fdd2f
:bug: Fix TID252 Prefer absolute imports over relative imports
shaneahmed Mar 19, 2025
16e3f91
:bug: Test mask to store from #918
shaneahmed Mar 21, 2025
b7862c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 21, 2025
a495daf
:bug: Fix AnnotationStore display
shaneahmed Mar 28, 2025
355af0d
:fire: Remove unnecessary test
shaneahmed Mar 31, 2025
9e30cb9
:wastebasket: Clean up code for saving to AnnotationStore
shaneahmed Mar 31, 2025
d9e8f92
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Mar 31, 2025
9011a8f
:bug: Fix mypy checks
shaneahmed Apr 3, 2025
9e1c218
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Apr 3, 2025
136f9f2
:memo: Add comments
shaneahmed Apr 9, 2025
854e2b8
:white_check_mark: Update single and multipoint behaviour
shaneahmed Apr 10, 2025
fb56920
:white_check_mark: Add tests for correct annotation type
shaneahmed Apr 10, 2025
221e53b
:stethoscope: Add checks for correct patch size.
shaneahmed Apr 10, 2025
f69ac89
:white_check_mark: Add tests for incorrect image patch input.
shaneahmed Apr 10, 2025
e9ed0a9
:bug: Fix tests for io config delegation.
shaneahmed Apr 10, 2025
c1b06d5
:package: Add sample image for semantic_segmentor tests.
shaneahmed Apr 10, 2025
211d39d
:bug: Fix conversion for single and two points annotations.
shaneahmed Apr 10, 2025
f2b4678
:bug: Fix semantic segmentor test
shaneahmed Apr 10, 2025
388bbc6
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Apr 10, 2025
56e195c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 10, 2025
c904626
:bug: Fix `ruff` errors.
shaneahmed Apr 10, 2025
fae76f8
Merge remote-tracking branch 'origin/dev-define-semantic-segmentor' i…
shaneahmed Apr 10, 2025
7ce679c
:bug: Fix offset values
shaneahmed Apr 10, 2025
f613cef
:bug: Fix tests
shaneahmed Apr 10, 2025
45a5737
:white_check_mark: Add tests for input numpy array
shaneahmed Apr 11, 2025
ab4e93b
:white_check_mark: Add tests to improve coverage.
shaneahmed Apr 11, 2025
a3174dc
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Apr 25, 2025
322d32e
:memo: Fix typos
shaneahmed Apr 25, 2025
fc89d81
:sparkles: Save `probabilities` to `tiff`
shaneahmed Apr 25, 2025
4abcd6f
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed May 9, 2025
2f60b97
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed May 14, 2025
22db63f
:bug: Fix bugs after merge
shaneahmed May 15, 2025
fe0f0ad
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed May 19, 2025
5972fac
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Jun 9, 2025
f2aa349
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Jun 10, 2025
d3f6790
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Jun 19, 2025
893abcb
:rewind: Revert padding with zeros in the output.
shaneahmed Jun 19, 2025
7cb05c2
:fire: Remove `mode` option from WSIPatchDataset
shaneahmed Jun 20, 2025
d032ebf
:hammer: Add patch_output_shape to WSIPatchDataset
shaneahmed Jun 20, 2025
a469c0f
:bulb: Add no cover to TYPE_CHECKING
shaneahmed Jun 20, 2025
51ba746
:hammer: Add `patch_output_shape` support.
shaneahmed Jun 20, 2025
49b997f
:hammer: Add output_locations as an attribute to SemanticSegmentor.
shaneahmed Jun 20, 2025
8f38244
:bug: Fix checking attribute for outputs in WSIPatchDataset
shaneahmed Jun 20, 2025
60e6959
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Jun 24, 2025
c6926f1
:twisted_rightwards_arrows: Merge dev-define-engine into dev-define-s…
shaneahmed Jul 9, 2025
31bb170
:memo: Fix `SemanticSegmentorRunParams`
shaneahmed Jul 10, 2025
db4788f
:alien: Add dataloader as an attribute to EngineABC
shaneahmed Jul 10, 2025
bfd02b5
:sparkles: Add merge_predictions and save to zarr
shaneahmed Jul 10, 2025
d08146e
:white_check_mark: Add test for coverage
shaneahmed Jul 10, 2025
ed52e08
:white_check_mark: Using small image for faster run
shaneahmed Jul 11, 2025
4f353f8
:white_check_mark: Improve Tests for WSI output
shaneahmed Jul 11, 2025
2b790b8
:white_check_mark: Improve Tests for WSI output
shaneahmed Jul 11, 2025
ceeda81
:white_check_mark: Add test for a large image
shaneahmed Jul 24, 2025
fe9f2e3
:lipstick: Add a function to misc to return an appropriate tqdm object.
shaneahmed Jul 24, 2025
7d02703
:lipstick: Add a function to misc to return an appropriate tqdm object.
shaneahmed Jul 24, 2025
efe0c80
:lipstick: Add a function to misc to return an appropriate tqdm object.
shaneahmed Jul 24, 2025
6d9b3c7
:zap: Improve patch output merging.
shaneahmed Jul 24, 2025
4fa584f
:sparkles: Add saving to AnnotationStore
shaneahmed Jul 24, 2025
3aaed3e
:fire: Purge old semantic_segmentor.py
shaneahmed Jul 24, 2025
6ccb300
:rewind: Add `DeepFeatureExtractor`.
shaneahmed Jul 25, 2025
3dd28bc
:rewind: Add `DeepFeatureExtractor`.
shaneahmed Jul 25, 2025
7651eda
:bug: Fix SemanticSegmentor import
shaneahmed Jul 25, 2025
a2d7fb8
:lipstick: Cosmetic changes
shaneahmed Jul 25, 2025
6296069
:zap: Using larger chunk size improves performance.
shaneahmed Jul 25, 2025
358c67e
:zap: Optimize performance
shaneahmed Jul 28, 2025
84a21cf
:bug: Fix tests
shaneahmed Jul 28, 2025
30621fc
:bug: Fix tests
shaneahmed Jul 29, 2025
55780a7
:fire: Remove smart_divide
shaneahmed Jul 29, 2025
8ab5293
:construction: Use dask for processing.
shaneahmed Jul 29, 2025
ef3c9f6
:construction: Use dask for processing.
shaneahmed Jul 29, 2025
61aee15
:bug: Fix bug saving annotation store.
shaneahmed Jul 29, 2025
4699643
:white_check_mark: Add test for TypeError
shaneahmed Jul 30, 2025
67fc4e2
:bug: Fix saving to zarr output
shaneahmed Jul 30, 2025
621adb5
:bug: Fix saving to annotationstore output
shaneahmed Jul 30, 2025
1d6b877
:zap: Improve merge performance
shaneahmed Jul 30, 2025
33d3c97
:zap: Using dask compute
shaneahmed Jul 31, 2025
2ab6238
:zap: Udate EngineABC for delayed compute.
shaneahmed Jul 31, 2025
2fd6623
:test_tube: Failing checks.
shaneahmed Aug 4, 2025
22fa379
:white_check_mark: Check tests for new dask based implementation.
shaneahmed Aug 4, 2025
69ce1b4
:fire: Remove self.keys
shaneahmed Aug 4, 2025
c39eece
:white_check_mark: Update for PatchPredictor
shaneahmed Aug 4, 2025
aa7e0c8
:bug: Fix bugs related to dictionary output.
shaneahmed Aug 4, 2025
e0f90f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 4, 2025
759d68e
Merge branch 'dev-define-engines-abc' into dev-define-semantic-segmentor
shaneahmed Aug 4, 2025
c4988c2
:bug: Fix ruff bug
shaneahmed Aug 4, 2025
4c4eff0
:white_check_mark: Fix failing test
shaneahmed Aug 4, 2025
70d7845
:bug: Fix bug and update documentation.
shaneahmed Aug 4, 2025
9ab67b0
:bug: Fix prepare_engines_save_dir
shaneahmed Aug 5, 2025
a7db423
:bug: Fix out_ shape
shaneahmed Aug 5, 2025
7f5d767
:memo: Update EngineABC documentation.
shaneahmed Aug 5, 2025
c05d0f3
:memo: Update _initialize_model_ioconfig documentation.
shaneahmed Aug 5, 2025
b12344b
:memo: Update get_dataloader documentation.
shaneahmed Aug 5, 2025
df1908a
:memo: Update _update_model_output documentation.
shaneahmed Aug 5, 2025
56db667
:memo: Update _get_coordinates documentation.
shaneahmed Aug 5, 2025
2561aad
:memo: Update process_batch documentation.
shaneahmed Aug 5, 2025
08c3dd5
:memo: Update infer_patches and post_process_patches.
shaneahmed Aug 5, 2025
6340091
:memo: Update save_predictions documentation.
shaneahmed Aug 5, 2025
5c63b3f
:memo: Update infer_wsi documentation.
shaneahmed Aug 5, 2025
b310930
:memo: Update post_process_wsi documentation.
shaneahmed Aug 5, 2025
52f3c8d
:memo: Update _load_ioconfig documentation.
shaneahmed Aug 5, 2025
ae7e4d6
:memo: Update _update_ioconfig documentation.
shaneahmed Aug 5, 2025
9f16e58
:memo: Update _validate_images_masks and _validate_input_numbers docu…
shaneahmed Aug 5, 2025
39e42a5
:memo: Update _update_run_params documentation.
shaneahmed Aug 5, 2025
8814001
:memo: Update _run_patch_mode documentation.
shaneahmed Aug 5, 2025
addce82
:memo: Update _calculate_scale_factor documentation.
shaneahmed Aug 5, 2025
4e748cd
:memo: Update _run_wsi_mode documentation.
shaneahmed Aug 5, 2025
6d882f4
:memo: Update run documentation.
shaneahmed Aug 5, 2025
1c57b55
:memo: Update engine_abc.py documentation.
shaneahmed Aug 5, 2025
3de2392
:memo: Update PatchPredictor __init__ documentation.
shaneahmed Aug 5, 2025
fe6d52e
:memo: Update PredictorRunParams documentation.
shaneahmed Aug 5, 2025
efc8106
:memo: Update PredictorRunParams documentation.
shaneahmed Aug 5, 2025
65939a0
:memo: Update post_process_patches documentation.
shaneahmed Aug 5, 2025
ba0bbba
:memo: Update post_process_wsi documentation.
shaneahmed Aug 5, 2025
9ad2d54
:memo: Update _update_run_params documentation.
shaneahmed Aug 5, 2025
0fae659
:memo: Update run documentation.
shaneahmed Aug 5, 2025
f813ffd
- Update semantic_segmentor.py using dask
shaneahmed Aug 5, 2025
3730fef
:bug: Fix unet test
shaneahmed Aug 6, 2025
dc3e6e1
:zap: Use basic loop to update.
shaneahmed Aug 6, 2025
edde87a
:fire: Remove intermediate calculation
shaneahmed Aug 6, 2025
29aeb1b
:fire: Remove unnecessary function
shaneahmed Aug 6, 2025
b22c303
:fire: Remove unnecessary function
shaneahmed Aug 6, 2025
d49c3d9
:fire: Remove unnecessary code
shaneahmed Aug 6, 2025
e9ac4e2
Merge remote-tracking branch 'origin/dev-define-semantic-segmentor' i…
shaneahmed Aug 7, 2025
daa1559
:zap: Update Segmentation Merge Code
shaneahmed Aug 7, 2025
04b9716
:zap: Update fcn_resnet50_unet-bcss stride shape
shaneahmed Aug 8, 2025
53792a7
:zap: Add merge_all to semantic_segmentor.py
shaneahmed Aug 8, 2025
4efa83d
:zap: Update SemanticSegmentor for merging logic
shaneahmed Aug 8, 2025
d202756
:memo: Fix annotations
shaneahmed Aug 8, 2025
f1b622e
:zap: Implement hybrid approach
shaneahmed Aug 11, 2025
daf7202
:memo: Update `infer_patches` documentation
shaneahmed Aug 11, 2025
45c6815
:memo: Update `__init__` documentation
shaneahmed Aug 11, 2025
a54ff40
:memo: Update `merge_all` documentation
shaneahmed Aug 11, 2025
ce6c8c0
:memo: Update `SemanticSegmentorRunParams` documentation
shaneahmed Aug 11, 2025
eed51a5
:memo: Update `SemanticSegmentor` documentation
shaneahmed Aug 11, 2025
3804524
:memo: Update `get_dataloader` documentation
shaneahmed Aug 11, 2025
6310ab5
:memo: Update `infer_wsi` documentation
shaneahmed Aug 11, 2025
f97a55c
:memo: Update `save_predictions` documentation
shaneahmed Aug 11, 2025
9f9a709
:memo: Update `run` documentation
shaneahmed Aug 11, 2025
775aa0a
:fire: Remove cache_mode
shaneahmed Aug 11, 2025
ee80501
:fire: Remove feature_extractor.py.
shaneahmed Aug 11, 2025
6e8819f
:zap: Rechunking strategy
shaneahmed Aug 12, 2025
fe060dd
:twisted_rightwards_arrows: Merge `dev-engine-abc` into `dev-define-s…
shaneahmed Aug 12, 2025
37ac626
:bug: Fix error after merge
shaneahmed Aug 12, 2025
5cda7c2
:bug: Fix error after merge
shaneahmed Aug 12, 2025
3150e32
:rewind: Revert previous commit
shaneahmed Aug 12, 2025
1b957d2
:zap: Improved run on bcss dataset by 4 times.
shaneahmed Aug 12, 2025
677284b
:recycle: Rename functions for clarity
shaneahmed Aug 13, 2025
b1763c9
:zap: Improve performance.
shaneahmed Aug 13, 2025
81d1d05
:white_check_mark: Improve coverage and address deepsource bugs
shaneahmed Aug 13, 2025
80691d1
:bug: Fix deepsource error
shaneahmed Aug 13, 2025
5b4a505
:zap: Move model to device in _update_run_params
shaneahmed Aug 13, 2025
708e706
:white_check_mark: Add tests to improve coverage
shaneahmed Aug 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ def sample_wsi_dict(remote_sample: Callable) -> dict:
"wsi4_4k_4k_svs",
"wsi3_20k_20k_pred",
"wsi4_4k_4k_pred",
"wsi4_1k_1k_svs",
]
return {name: remote_sample(name) for name in file_names}

Expand Down
14 changes: 7 additions & 7 deletions tests/engines/test_engine_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
eng = TestEngineABC(model=model)

kwargs = {
"patch_input_shape": [512, 512],
"patch_input_shape": [224, 224],
"input_resolutions": [{"units": "mpp", "resolution": 1.75}],
}
with caplog.at_level(logging.WARNING):
Expand All @@ -536,7 +536,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)

# test providing config / full input info for non pretrained models
ioconfig = ModelIOConfigABC(
patch_input_shape=(512, 512),
patch_input_shape=(224, 224),
stride_shape=(256, 256),
input_resolutions=[{"resolution": 1.35, "units": "mpp"}],
)
Expand All @@ -546,7 +546,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
save_dir=f"{tmp_path}/dump",
ioconfig=ioconfig,
)
assert eng._ioconfig.patch_input_shape == (512, 512)
assert eng._ioconfig.patch_input_shape == (224, 224)
assert eng._ioconfig.stride_shape == (256, 256)
assert eng._ioconfig.input_resolutions == [{"resolution": 1.35, "units": "mpp"}]
shutil.rmtree(tmp_path / "dump", ignore_errors=True)
Expand All @@ -557,15 +557,15 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
save_dir=f"{tmp_path}/dump",
**kwargs,
)
assert eng._ioconfig.patch_input_shape == [512, 512]
assert eng._ioconfig.stride_shape == [512, 512]
assert eng._ioconfig.patch_input_shape == [224, 224]
assert eng._ioconfig.stride_shape == [224, 224]
assert eng._ioconfig.input_resolutions == [{"resolution": 1.75, "units": "mpp"}]
shutil.rmtree(tmp_path / "dump", ignore_errors=True)

# test overwriting pretrained ioconfig
eng = TestEngineABC(model="alexnet-kather100k")
eng.run(
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
images=np.zeros((10, 300, 300, 3), dtype=np.uint8),
patch_input_shape=(300, 300),
stride_shape=(300, 300),
input_resolutions=[{"units": "baseline", "resolution": 1.99}],
Expand All @@ -579,7 +579,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture)
shutil.rmtree(tmp_path / "dump", ignore_errors=True)

eng.run(
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
images=np.zeros((10, 300, 300, 3), dtype=np.uint8),
patch_input_shape=(300, 300),
stride_shape=(300, 300),
input_resolutions=None,
Expand Down
193 changes: 193 additions & 0 deletions tests/engines/test_semantic_segmentor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
"""Test SemanticSegmentor."""

from __future__ import annotations

import json
import sqlite3
from typing import TYPE_CHECKING, Callable

import numpy as np
import torch
import zarr

from tiatoolbox.annotation import SQLiteStore
from tiatoolbox.models.engine.semantic_segmentor_new import SemanticSegmentor
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils.misc import imread

if TYPE_CHECKING:
from pathlib import Path

device = "cuda" if toolbox_env.has_gpu() else "cpu"


def test_semantic_segmentor_init() -> None:
"""Tests SemanticSegmentor initialization."""
segmentor = SemanticSegmentor(model="fcn-tissue_mask", device=device)

assert isinstance(segmentor, SemanticSegmentor)
assert isinstance(segmentor.model, torch.nn.Module)


def test_semantic_segmentor_patches(remote_sample: Callable, tmp_path: Path) -> None:
"""Tests SemanticSegmentor on image patches."""
segmentor = SemanticSegmentor(
model="fcn-tissue_mask", batch_size=32, verbose=False, device=device
)

sample_image = remote_sample("thumbnail-1k-1k")

inputs = [sample_image, sample_image]

assert segmentor.cache_mode is False

output = segmentor.run(
images=inputs,
return_probabilities=True,
return_labels=False,
device=device,
patch_mode=True,
)

assert 0.15 < np.mean(output["predictions"][:]) < 0.18
assert 0.495 < np.mean(output["probabilities"][:]) < 0.505

assert (
tuple(segmentor._ioconfig.patch_output_shape)
== output["probabilities"][0].shape[:-1]
)

assert (
tuple(segmentor._ioconfig.patch_input_shape) == output["predictions"][0].shape
)

output = segmentor.run(
images=inputs,
return_probabilities=True,
return_labels=False,
device=device,
patch_mode=True,
cache_mode=True,
save_dir=tmp_path / "output0",
)

assert output == tmp_path / "output0" / "output.zarr"

output = zarr.open(output, mode="r")
assert 0.15 < np.mean(output["predictions"][:]) < 0.18
assert 0.495 < np.mean(output["probabilities"][:]) < 0.505

output = segmentor.run(
images=inputs,
return_probabilities=False,
return_labels=False,
device=device,
patch_mode=True,
cache_mode=True,
output_type="zarr",
save_dir=tmp_path / "output1",
)

assert output == tmp_path / "output1" / "output.zarr"

output = zarr.open(output, mode="r")
assert 0.15 < np.mean(output["predictions"][:]) < 0.18
assert "probabilities" not in output.keys() # noqa: SIM118

output = segmentor.run(
images=inputs,
return_probabilities=False,
return_labels=False,
device=device,
patch_mode=True,
cache_mode=False,
save_dir=tmp_path / "output2",
output_type="zarr",
)

assert output == tmp_path / "output2" / "output.zarr"

output = zarr.open(output, mode="r")
assert 0.15 < np.mean(output["predictions"][:]) < 0.18
assert "probabilities" not in output
assert "predictions" in output


def _test_store_output_patch(output: Path) -> None:
"""Helper method to test annotation store output for a patch."""
store_ = SQLiteStore.open(output)
annotations_ = store_.values()
annotations_geometry_type = [
str(annotation_.geometry_type) for annotation_ in annotations_
]
assert "Polygon" in annotations_geometry_type

con = sqlite3.connect(output)
cur = con.cursor()
annotations_properties = list(cur.execute("SELECT properties FROM annotations"))

out = []

for item in annotations_properties:
for json_str in item:
probs = json.loads(json_str)
if "type" in probs:
out.append(probs.pop("type"))

assert "mask" in out

assert annotations_properties is not None


def test_save_annotation_store(remote_sample: Callable, tmp_path: Path) -> None:
"""Test for saving output as annotation store."""
segmentor = SemanticSegmentor(
model="fcn-tissue_mask", batch_size=32, verbose=False, device=device
)

sample_image = remote_sample("thumbnail-1k-1k")

inputs = [sample_image]

output = segmentor.run(
images=inputs,
return_probabilities=False,
return_labels=False,
device=device,
patch_mode=True,
cache_mode=True,
save_dir=tmp_path / "output1",
output_type="annotationstore",
)

assert output[0] == tmp_path / "output1" / (sample_image.stem + ".db")
_test_store_output_patch(output[0])


def test_save_annotation_store_nparray(remote_sample: Callable, tmp_path: Path) -> None:
"""Test for saving output as annotation store using a numpy array."""
segmentor = SemanticSegmentor(
model="fcn-tissue_mask", batch_size=32, verbose=False, device=device
)

sample_image = remote_sample("thumbnail-1k-1k")

input_image = imread(sample_image)
inputs_list = [input_image, input_image]

output = segmentor.run(
images=inputs_list,
return_probabilities=False,
return_labels=False,
device=device,
patch_mode=True,
cache_mode=True,
save_dir=tmp_path / "output1",
output_type="annotationstore",
)

assert output[0] == tmp_path / "output1" / "0.db"
assert output[1] == tmp_path / "output1" / "1.db"

_test_store_output_patch(output[0])
_test_store_output_patch(output[1])
2 changes: 1 addition & 1 deletion tests/models/test_arch_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_functional_unet(remote_sample: Callable) -> None:
pretrained = torch.load(pretrained_weights, map_location="cpu")
model.load_state_dict(pretrained)
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
_ = output[0]
_ = output["probabilities"][0]

# run untrained network to test for architecture
model = UNetModel(
Expand Down
29 changes: 23 additions & 6 deletions tests/models/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from tiatoolbox.utils import download_data, imread, imwrite, unzip_data
from tiatoolbox.utils import env_detection as toolbox_env
from tiatoolbox.utils.exceptions import DimensionMismatchError
from tiatoolbox.wsicore import WSIReader

RNG = np.random.default_rng() # Numpy Random Generator
Expand Down Expand Up @@ -120,7 +121,9 @@ def test_kather_dataset(tmp_path: Path) -> None:
assert len(dataset.inputs) == len(dataset.labels)

# to actually get the image, we feed it to PatchDataset
actual_ds = PatchDataset(dataset.inputs, dataset.labels)
actual_ds = PatchDataset(
dataset.inputs, dataset.labels, patch_input_shape=(224, 224)
)
sample_patch = actual_ds[89]
assert isinstance(sample_patch["image"], np.ndarray)
assert sample_patch["label"] is not None
Expand All @@ -129,14 +132,28 @@ def test_kather_dataset(tmp_path: Path) -> None:
shutil.rmtree(save_dir_path, ignore_errors=True)


def test_incorrect_input_shape() -> None:
"""Incorrect input patch dimensions should raise DimensionMismatchError."""
size = (5, 5, 3)
img = RNG.integers(low=0, high=255, size=size)
list_imgs = [img, img, img]
dataset = PatchDataset(list_imgs, patch_input_shape=(100, 100))
with pytest.raises(
DimensionMismatchError, match=r".*\(100, 100\), but got \(5, 5\).*"
):
_ = dataset[0]


def test_patch_dataset_path_imgs(
sample_patch1: str | Path,
sample_patch2: str | Path,
) -> None:
"""Test for patch dataset with a list of file paths as input."""
size = (224, 224, 3)

dataset = PatchDataset([Path(sample_patch1), Path(sample_patch2)])
dataset = PatchDataset(
[Path(sample_patch1), Path(sample_patch2)], patch_input_shape=size[:-1]
)

for _, sample_data in enumerate(dataset):
sampled_img_shape = sample_data["image"].shape
Expand All @@ -152,7 +169,7 @@ def test_patch_dataset_list_imgs(tmp_path: Path) -> None:
size = (5, 5, 3)
img = RNG.integers(low=0, high=255, size=size)
list_imgs = [img, img, img]
dataset = PatchDataset(list_imgs)
dataset = PatchDataset(list_imgs, patch_input_shape=size[:-1])

dataset.preproc_func = lambda x: x

Expand Down Expand Up @@ -197,14 +214,14 @@ def test_patch_datasetarray_imgs() -> None:
array_imgs = np.array(list_imgs)

# test different setter for label
dataset = PatchDataset(array_imgs, labels=labels)
dataset = PatchDataset(array_imgs, labels=labels, patch_input_shape=(5, 5))
an_item = dataset[2]
assert an_item["label"] == 3
dataset = PatchDataset(array_imgs, labels=None)
dataset = PatchDataset(array_imgs, labels=None, patch_input_shape=(5, 5))
an_item = dataset[2]
assert "label" not in an_item

dataset = PatchDataset(array_imgs)
dataset = PatchDataset(array_imgs, patch_input_shape=size[:-1])
for _, sample_data in enumerate(dataset):
sampled_img_shape = sample_data["image"].shape
assert sampled_img_shape[0] == size[0]
Expand Down
Loading