Commit 2f88cb7

🚧 Update script for annotation store
1 parent 816a568 commit 2f88cb7

3 files changed (+38, -57 lines)

tests/engines/test_semantic_segmentor.py

24 additions, 0 deletions

```diff
@@ -87,6 +87,30 @@ def test_semantic_segmentor_patches(
     assert "probabilities" not in output.keys()  # noqa: SIM118
 
 
+def test_save_annotation_store(sample_patch1: Path, sample_patch2: Path, tmp_path: Path):
+    segmentor = SemanticSegmentor(
+        model="fcn-tissue_mask", batch_size=32, verbose=False, device=device
+    )
+
+    inputs = [Path(sample_patch1), Path(sample_patch2)]
+    output = segmentor.run(
+        images=inputs,
+        return_probabilities=False,
+        return_labels=False,
+        device=device,
+        patch_mode=True,
+        cache_mode=True,
+        save_dir=tmp_path / "output1",
+        output_type="annotationstore",
+    )
+
+    assert output == tmp_path / "output1" / "output.zarr"
+
+    output = zarr.open(output, mode="r")
+    assert 0.24 < np.mean(output["predictions"][:]) < 0.25
+    assert "probabilities" not in output.keys()  # noqa: SIM118
+
+
 def test_hovernet_dat() -> None:
     from tiatoolbox.utils.misc import store_from_dat
     from pathlib import Path
```
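
For context, a minimal sketch of inspecting the cached output outside the test; the `save_dir / "output.zarr"` layout and the `predictions` key are assumed from the assertions above, not from any documented API.

```python
import numpy as np
import zarr

# Open the cached segmentation result written by SemanticSegmentor.run
# (path assumed to mirror the test's save_dir / "output.zarr" layout).
output = zarr.open("output1/output.zarr", mode="r")

preds = np.asarray(output["predictions"])  # per-pixel prediction map(s)
print(preds.shape, float(preds.mean()))    # the test expects a mean in (0.24, 0.25)
```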

tiatoolbox/models/engine/semantic_segmentor_new.py

1 addition, 1 deletion

```diff
@@ -7,13 +7,13 @@
 
 import zarr
 from typing_extensions import Unpack
+from pathlib import Path
 
 from .patch_predictor import PatchPredictor, PredictorRunParams
 from ...utils.misc import dict_to_zarr, dict_to_store_semantic_segmentor
 
 if TYPE_CHECKING:  # pragma: no cover
     import os
-    from pathlib import Path
 
     import numpy as np
```

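The `Path` import moves out of the `TYPE_CHECKING` block, presumably because `Path` is now needed at runtime rather than only in type annotations. A minimal illustration of the pattern (the `to_path` helper is hypothetical, not from this module):

```python
from pathlib import Path  # runtime import: Path is constructed below
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # pragma: no cover
    import os  # imported only for static type checking


def to_path(save_dir: "str | os.PathLike[str]") -> Path:
    # Path is called at runtime here, so importing it only under
    # TYPE_CHECKING would raise NameError.
    return Path(save_dir)


print(to_path("output1"))
```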
tiatoolbox/utils/misc.py

13 additions, 56 deletions

```diff
@@ -1258,6 +1258,7 @@ def dict_to_store_semantic_segmentor(
 
     """
    preds = patch_output["predictions"]
+    preds = preds[0]
    layer_list = np.unique(preds)
    layer_list = np.delete(layer_list, np.where(layer_list == 0))
    layer_info_dict = {}
@@ -1273,76 +1274,32 @@
            cv2.RETR_TREE,
            cv2.CHAIN_APPROX_NONE,
        )
-        for layer in contours:
-            coords = layer[:, 0, :]
+        for layer_ in contours:
+            coords = layer_[:, 0, :]
            layer_info_dict[count] = {
                "contours": coords,
-                "type": class_dict[type_class],
+                "type": "mask",
            }
            count += 1
 
            origin = (0, 0)
-
-            annotations.append(
-                Annotation(
-                    geometry=make_valid_poly(
-                        feature2geometry(
+            scaled_coords = np.array([scale_factor * coords])
+            feature_geom = feature2geometry(
                {
                    "type": "Polygon",
-                    "coordinates": scale_factor * coords,
+                    "coordinates": scaled_coords,
                },
-                ),
+            )
+            annotations.append(
+                Annotation(
+                    geometry=make_valid_poly(
+                        feature_geom,
                        origin=origin,
                    ),
-                    properties={},
+                    properties={"type": "mask"},
                )
            )
 
-    # return layer_info_dict
-
-    # if "coordinates" not in patch_output:
-    #     # we cant create annotations without coordinates
-    #     msg = "Patch output must contain coordinates."
-    #     raise ValueError(msg)
-    #
-    # # get relevant keys
-    # class_probs = get_zarr_array(patch_output.get("probabilities", []))
-    # preds = get_zarr_array(patch_output.get("predictions", []))
-    #
-    # patch_coords = np.array(patch_output.get("coordinates", []))
-    # if not np.all(np.array(scale_factor) == 1):
-    #     patch_coords = patch_coords * (np.tile(scale_factor, 2))  # to baseline mpp
-    # patch_coords = patch_coords.astype(float)
-    # labels = patch_output.get("labels", [])
-    # # get classes to consider
-    # if len(class_probs) == 0:
-    #     classes_predicted = np.unique(preds).tolist()
-    # else:
-    #     classes_predicted = range(len(class_probs[0]))
-    #
-    # if class_dict is None:
-    #     # if no class dict create a default one
-    #     if len(class_probs) == 0:
-    #         class_dict = {i: i for i in np.unique(np.append(preds, labels)).tolist()}
-    #     else:
-    #         class_dict = {i: i for i in range(len(class_probs[0]))}
-    #
-    # # find what keys we need to save
-    # keys = ["predictions"]
-    # keys = keys + [key for key in ["probabilities", "labels"] if key in patch_output]
-    #
-    # # put patch predictions into a store
-    # annotations = patch_predictions_as_annotations(
-    #     preds,
-    #     keys,
-    #     class_dict,
-    #     class_probs,
-    #     patch_coords,
-    #     classes_predicted,
-    #     labels,
-    # )
-    #
-    # store = SQLiteStore()
    _ = store.append_many(annotations, [str(i) for i in range(len(annotations))])
 
    # # if a save director is provided, then dump store into a file
```
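
Taken together, the rewrite flattens the deeply nested `annotations.append(...)` call into named intermediates (`scaled_coords`, `feature_geom`), tags each contour polygon as `"mask"` instead of looking it up in `class_dict`, and drops the commented-out remnants of the patch-predictor code path; the added `preds = preds[0]` in the first hunk presumably selects the first prediction map before contours are traced. Below is a self-contained sketch of the resulting flow, under stated assumptions: a toy prediction mask, an assumed `scale_factor`, and a plain shapely `Polygon` standing in for the `feature2geometry`/`make_valid_poly` pipeline.

```python
import cv2
import numpy as np
from shapely.geometry import Polygon
from tiatoolbox.annotation.storage import Annotation, SQLiteStore

preds = np.zeros((64, 64), dtype=np.uint8)
preds[16:48, 16:48] = 1  # toy prediction map with one foreground class
scale_factor = 2.0       # assumed patch-to-baseline scaling

annotations = []
layer_list = np.unique(preds)
layer_list = np.delete(layer_list, np.where(layer_list == 0))  # drop background
for type_class in layer_list:
    binary = (preds == type_class).astype(np.uint8)
    contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    for contour in contours:
        coords = contour[:, 0, :] * scale_factor  # scale to baseline coordinates
        annotations.append(
            Annotation(geometry=Polygon(coords), properties={"type": "mask"})
        )

# As in the diff, keys are stringified indices into the annotation list.
store = SQLiteStore()
_ = store.append_many(annotations, [str(i) for i in range(len(annotations))])
print(len(store))
```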
