Skip to content

Commit 8a95948

Browse files
committed
✅ Check for cache_mode with zarr output.
1 parent 91a45b3 commit 8a95948

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

tests/engines/test_semantic_segmentor.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
from pathlib import Path
66

7+
import numpy as np
78
import torch
9+
import zarr
810

911
from tiatoolbox.models.engine.semantic_segmentor_new import SemanticSegmentor
1012
from tiatoolbox.utils import env_detection as toolbox_env
@@ -20,7 +22,9 @@ def test_semantic_segmentor_init() -> None:
2022
assert isinstance(segmentor.model, torch.nn.Module)
2123

2224

23-
def test_semantic_segmentor_patches(sample_patch1: Path, sample_patch2: Path) -> None:
25+
def test_semantic_segmentor_patches(
26+
sample_patch1: Path, sample_patch2: Path, tmp_path: Path
27+
) -> None:
2428
"""Tests SemanticSegmentor on image patches."""
2529
segmentor = SemanticSegmentor(
2630
model="fcn-tissue_mask", batch_size=32, verbose=False, device=device
@@ -38,6 +42,9 @@ def test_semantic_segmentor_patches(sample_patch1: Path, sample_patch2: Path) ->
3842
patch_mode=True,
3943
)
4044

45+
assert 0.24 < np.mean(output["predictions"][:]) < 0.25
46+
assert 0.495 < np.mean(output["probabilities"][:]) < 0.505
47+
4148
assert (
4249
tuple(segmentor._ioconfig.patch_output_shape)
4350
== output["probabilities"][0].shape[:-1]
@@ -46,3 +53,35 @@ def test_semantic_segmentor_patches(sample_patch1: Path, sample_patch2: Path) ->
4653
assert (
4754
tuple(segmentor._ioconfig.patch_output_shape) == output["predictions"][0].shape
4855
)
56+
57+
output = segmentor.run(
58+
images=inputs,
59+
return_probabilities=True,
60+
return_labels=False,
61+
device=device,
62+
patch_mode=True,
63+
cache_mode=True,
64+
save_dir=tmp_path / "output0",
65+
)
66+
67+
assert output == tmp_path / "output0" / "output.zarr"
68+
69+
output = zarr.open(output, mode="r")
70+
assert 0.24 < np.mean(output["predictions"][:]) < 0.25
71+
assert 0.495 < np.mean(output["probabilities"][:]) < 0.505
72+
73+
output = segmentor.run(
74+
images=inputs,
75+
return_probabilities=False,
76+
return_labels=False,
77+
device=device,
78+
patch_mode=True,
79+
cache_mode=True,
80+
save_dir=tmp_path / "output1",
81+
)
82+
83+
assert output == tmp_path / "output1" / "output.zarr"
84+
85+
output = zarr.open(output, mode="r")
86+
assert 0.24 < np.mean(output["predictions"][:]) < 0.25
87+
assert "probabilities" not in output.keys() # noqa: SIM118

0 commit comments

Comments
 (0)