4
4
5
5
from pathlib import Path
6
6
7
+ import numpy as np
7
8
import torch
9
+ import zarr
8
10
9
11
from tiatoolbox .models .engine .semantic_segmentor_new import SemanticSegmentor
10
12
from tiatoolbox .utils import env_detection as toolbox_env
@@ -20,7 +22,9 @@ def test_semantic_segmentor_init() -> None:
20
22
assert isinstance (segmentor .model , torch .nn .Module )
21
23
22
24
23
- def test_semantic_segmentor_patches (sample_patch1 : Path , sample_patch2 : Path ) -> None :
25
+ def test_semantic_segmentor_patches (
26
+ sample_patch1 : Path , sample_patch2 : Path , tmp_path : Path
27
+ ) -> None :
24
28
"""Tests SemanticSegmentor on image patches."""
25
29
segmentor = SemanticSegmentor (
26
30
model = "fcn-tissue_mask" , batch_size = 32 , verbose = False , device = device
@@ -38,6 +42,9 @@ def test_semantic_segmentor_patches(sample_patch1: Path, sample_patch2: Path) ->
38
42
patch_mode = True ,
39
43
)
40
44
45
+ assert 0.24 < np .mean (output ["predictions" ][:]) < 0.25
46
+ assert 0.495 < np .mean (output ["probabilities" ][:]) < 0.505
47
+
41
48
assert (
42
49
tuple (segmentor ._ioconfig .patch_output_shape )
43
50
== output ["probabilities" ][0 ].shape [:- 1 ]
@@ -46,3 +53,35 @@ def test_semantic_segmentor_patches(sample_patch1: Path, sample_patch2: Path) ->
46
53
assert (
47
54
tuple (segmentor ._ioconfig .patch_output_shape ) == output ["predictions" ][0 ].shape
48
55
)
56
+
57
+ output = segmentor .run (
58
+ images = inputs ,
59
+ return_probabilities = True ,
60
+ return_labels = False ,
61
+ device = device ,
62
+ patch_mode = True ,
63
+ cache_mode = True ,
64
+ save_dir = tmp_path / "output0" ,
65
+ )
66
+
67
+ assert output == tmp_path / "output0" / "output.zarr"
68
+
69
+ output = zarr .open (output , mode = "r" )
70
+ assert 0.24 < np .mean (output ["predictions" ][:]) < 0.25
71
+ assert 0.495 < np .mean (output ["probabilities" ][:]) < 0.505
72
+
73
+ output = segmentor .run (
74
+ images = inputs ,
75
+ return_probabilities = False ,
76
+ return_labels = False ,
77
+ device = device ,
78
+ patch_mode = True ,
79
+ cache_mode = True ,
80
+ save_dir = tmp_path / "output1" ,
81
+ )
82
+
83
+ assert output == tmp_path / "output1" / "output.zarr"
84
+
85
+ output = zarr .open (output , mode = "r" )
86
+ assert 0.24 < np .mean (output ["predictions" ][:]) < 0.25
87
+ assert "probabilities" not in output .keys () # noqa: SIM118
0 commit comments