Commit cf5b50e

✅ Update script for annotation store
1 parent 8a95948 commit cf5b50e

5 files changed: +206 additions, -14 deletions

tests/engines/test_semantic_segmentor.py

Lines changed: 7 additions & 0 deletions

@@ -85,3 +85,10 @@ def test_semantic_segmentor_patches(
     output = zarr.open(output, mode="r")
     assert 0.24 < np.mean(output["predictions"][:]) < 0.25
     assert "probabilities" not in output.keys()  # noqa: SIM118
+
+
+# def test_hovernet_dat() -> None:
+#     from tiatoolbox.utils.misc import store_from_dat
+#     from pathlib import Path
+#     path_to_file = Path.cwd().parent.parent / "output" / "0.dat"
+#     out = store_from_dat(path_to_file, scale_factor=(1.0, 1.0))
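For context, a rough sketch of how the commented-out check above might read once enabled. It assumes a HoVerNet `0.dat` output file exists at the path used in the comment and that `store_from_dat` returns a populated `SQLiteStore`; it is illustrative only and not part of this commit.

def test_hovernet_dat() -> None:
    """Hypothetical test: load a HoVerNet .dat output into an annotation store."""
    from pathlib import Path

    from tiatoolbox.annotation.storage import SQLiteStore
    from tiatoolbox.utils.misc import store_from_dat

    # Assumes a HoVerNet result file exists at this (illustrative) location.
    path_to_file = Path.cwd().parent.parent / "output" / "0.dat"
    store = store_from_dat(path_to_file, scale_factor=(1.0, 1.0))

    assert isinstance(store, SQLiteStore)
    assert len(store) > 0  # at least one annotation was loaded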

tests/test_utils.py

Lines changed: 9 additions & 9 deletions

@@ -1654,7 +1654,7 @@ def test_patch_pred_store() -> None:
         "other": "other",
     }
 
-    store = misc.dict_to_store(patch_output, (1.0, 1.0))
+    store = misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0))
 
     # Check that it is an SQLiteStore containing the expected annotations
     assert isinstance(store, SQLiteStore)
@@ -1667,15 +1667,15 @@ def test_patch_pred_store() -> None:
     patch_output.pop("coordinates")
     # check correct error is raised if coordinates are missing
     with pytest.raises(ValueError, match="coordinates"):
-        misc.dict_to_store(patch_output, (1.0, 1.0))
+        misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0))
 
     patch_output = {
         "predictions": [1, 0, 1],
         "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
         "other": "other",
     }
 
-    store = misc.dict_to_store(patch_output, (1.0, 1.0))
+    store = misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0))
 
     # Check that it is an SQLiteStore containing the expected annotations
     assert isinstance(store, SQLiteStore)
@@ -1692,7 +1692,7 @@ def test_patch_pred_store_cdict() -> None:
         "other": "other",
     }
     class_dict = {0: "class0", 1: "class1"}
-    store = misc.dict_to_store(patch_output, (1.0, 1.0), class_dict=class_dict)
+    store = misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0), class_dict=class_dict)
 
     # Check that it is an SQLiteStore containing the expected annotations
     assert isinstance(store, SQLiteStore)
@@ -1713,7 +1713,7 @@ def test_patch_pred_store_sf() -> None:
         "probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]],
         "labels": [1, 0, 1],
     }
-    store = misc.dict_to_store(patch_output, (2.0, 2.0))
+    store = misc.dict_to_store_patch_predictions(patch_output, (2.0, 2.0))
 
     # Check that its an SQLiteStore containing the expected annotations
     assert isinstance(store, SQLiteStore)
@@ -1770,7 +1770,7 @@ def test_patch_pred_store_persist(tmp_path: pytest.TempPathFactory) -> None:
     }
     save_path = tmp_path / "patch_output" / "output.db"
 
-    store_path = misc.dict_to_store(patch_output, (1.0, 1.0), save_path=save_path)
+    store_path = misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0), save_path=save_path)
 
     print("Annotation store path: ", store_path)
     assert Path.exists(store_path), "Annotation Store output file does not exist"
@@ -1788,7 +1788,7 @@ def test_patch_pred_store_persist(tmp_path: pytest.TempPathFactory) -> None:
     patch_output.pop("coordinates")
     # check correct error is raised if coordinates are missing
     with pytest.raises(ValueError, match="coordinates"):
-        misc.dict_to_store(patch_output, (1.0, 1.0))
+        misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0))
 
 
 def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None:
@@ -1804,7 +1804,7 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None:
     # sends the path of a jpeg source image, expects .db file in the same directory
     save_path = tmp_path / "patch_output" / "output.jpeg"
 
-    store_path = misc.dict_to_store(patch_output, (1.0, 1.0), save_path=save_path)
+    store_path = misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0), save_path=save_path)
 
     print("Annotation store path: ", store_path)
     assert Path.exists(store_path), "Annotation Store output file does not exist"
@@ -1822,7 +1822,7 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None:
     patch_output.pop("coordinates")
     # check correct error is raised if coordinates are missing
     with pytest.raises(ValueError, match="coordinates"):
-        misc.dict_to_store(patch_output, (1.0, 1.0))
+        misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0))
 
 
 def test_torch_compile_already_compiled() -> None:
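The rename is mechanical: every call site of `dict_to_store` now uses `dict_to_store_patch_predictions` with the same arguments. A minimal, self-contained sketch mirroring the dictionaries used in these tests (the values are arbitrary):

from tiatoolbox.annotation.storage import SQLiteStore
from tiatoolbox.utils import misc

patch_output = {
    "predictions": [1, 0, 1],
    "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
}
# Coordinates are multiplied by scale_factor, so (1.0, 1.0) keeps them at baseline.
store = misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0))
assert isinstance(store, SQLiteStore)  # an in-memory store when no save_path is given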

tiatoolbox/models/engine/engine_abc.py

Lines changed: 3 additions & 3 deletions

@@ -21,7 +21,7 @@
 from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset
 from tiatoolbox.models.models_abc import load_torch_model
 from tiatoolbox.utils.misc import (
-    dict_to_store,
+    dict_to_store_patch_predictions,
     dict_to_zarr,
     write_to_zarr_in_cache_mode,
 )
@@ -633,7 +633,7 @@ def save_predictions(
         processed_predictions: dict | Path,
         output_type: str,
         save_dir: Path | None = None,
-        **kwargs: dict,
+        **kwargs: EngineABCRunParams,
     ) -> dict | AnnotationStore | Path:
         """Save model predictions.
 
@@ -679,7 +679,7 @@ def save_predictions(
                 processed_predictions_path = processed_predictions
                 processed_predictions = zarr.open(processed_predictions, mode="r")
 
-            out_file = dict_to_store(
+            out_file = dict_to_store_patch_predictions(
                 processed_predictions,
                 scale_factor,
                 class_dict,
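The only behavioural note here is that `**kwargs` is now typed as `EngineABCRunParams` rather than a bare `dict`. A small, hedged sketch of the run parameters this method reads; only the keys that this commit itself accesses (`output_file`, `scale_factor`, `class_dict`) are assumed, and the final call is commented out because it needs a constructed engine:

# Keys assumed from the kwargs.get(...) calls in this commit; other
# EngineABCRunParams fields are not shown here.
run_params = {
    "output_file": "output.db",                # destination file name
    "scale_factor": (2.0, 2.0),                # model_mpp / slide_mpp
    "class_dict": {0: "class0", 1: "class1"},  # index -> class name mapping
}
# engine.save_predictions(processed_predictions, "AnnotationStore", save_dir, **run_params)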

tiatoolbox/models/engine/semantic_segmentor_new.py

Lines changed: 75 additions & 0 deletions

@@ -2,11 +2,14 @@
 
 from __future__ import annotations
 
+import shutil
 from typing import TYPE_CHECKING
 
+import zarr
 from typing_extensions import Unpack
 
 from .patch_predictor import PatchPredictor, PredictorRunParams
+from ...utils.misc import dict_to_zarr, dict_to_store_semantic_segmentor
 
 if TYPE_CHECKING:  # pragma: no cover
     import os
@@ -297,6 +300,78 @@ def __init__(
             verbose=verbose,
         )
 
+    def save_predictions(
+        self: PatchPredictor,
+        processed_predictions: dict | Path,
+        output_type: str,
+        save_dir: Path | None = None,
+        **kwargs: SemanticSegmentorRunParams,
+    ) -> dict | AnnotationStore | Path:
+        """Save semantic segmentation predictions to disk.
+
+        Args:
+            processed_predictions (dict | Path):
+                A dictionary or path to zarr with model prediction information.
+            save_dir (Path):
+                Optional output path to directory to save the patch dataset output to a
+                `.zarr` or `.db` file, provided `patch_mode` is True. If the
+                `patch_mode` is False then `save_dir` is required.
+            output_type (str):
+                The desired output type for resulting patch dataset.
+            **kwargs (SemanticSegmentorRunParams):
+                Keyword Args required to save the output.
+
+        Returns:
+            dict or Path or :class:`AnnotationStore`:
+                If the `output_type` is "AnnotationStore", the function returns
+                the patch predictor output as an SQLiteStore containing Annotations
+                for each or the Path to a `.db` file depending on whether a
+                save_dir Path is provided. Otherwise, the function defaults to
+                returning patch predictor output, either as a dict or the Path to a
+                `.zarr` file depending on whether a save_dir Path is provided.
+
+        """
+        if (
+            self.cache_mode or not save_dir
+        ) and output_type.lower() != "annotationstore":
+            return processed_predictions
+
+        save_path = Path(kwargs.get("output_file", save_dir / "output.db"))
+
+        if output_type.lower() == "annotationstore":
+            # scale_factor set from kwargs
+            scale_factor = kwargs.get("scale_factor", (1.0, 1.0))
+            # class_dict set from kwargs
+            class_dict = kwargs.get("class_dict")
+
+            processed_predictions_path: str | Path | None = None
+
+            # Need to add support for zarr conversion.
+            if self.cache_mode:
+                processed_predictions_path = processed_predictions
+                processed_predictions = zarr.open(processed_predictions, mode="r")
+
+            out_file = dict_to_store_semantic_segmentor(
+                processed_predictions,
+                scale_factor,
+                class_dict,
+                save_path,
+            )
+            if processed_predictions_path is not None:
+                shutil.rmtree(processed_predictions_path)
+
+            return out_file
+
+        return (
+            dict_to_zarr(
+                processed_predictions,
+                save_path,
+                **kwargs,
+            )
+            if isinstance(processed_predictions, dict)
+            else processed_predictions
+        )
+
     def run(
         self: SemanticSegmentor,
         images: list[os | Path | WSIReader] | np.ndarray,
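The docstring above lists three possible outcomes for `save_predictions`. Below is a small, runnable mirror of that routing, written as a standalone sketch rather than the engine code itself; the `route_output` helper and its return strings are purely illustrative:

from __future__ import annotations

from pathlib import Path


def route_output(output_type: str, save_dir: Path | None, *, cache_mode: bool = False) -> str:
    """Report which branch the new save_predictions would take (sketch only)."""
    if (cache_mode or not save_dir) and output_type.lower() != "annotationstore":
        return "predictions returned unchanged (dict or .zarr path)"
    if output_type.lower() == "annotationstore":
        return "dict_to_store_semantic_segmentor -> SQLiteStore / .db"
    return "dict_to_zarr -> .zarr"


print(route_output("AnnotationStore", Path("results")))  # annotation store branch
print(route_output("zarr", Path("results")))             # zarr branch
print(route_output("dict", None))                        # returned unchanged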

tiatoolbox/utils/misc.py

Lines changed: 112 additions & 2 deletions

@@ -1228,6 +1228,116 @@ def patch_predictions_as_annotations(
     return annotations
 
 
+def dict_to_store_semantic_segmentor(
+    patch_output: dict | zarr.group,
+    scale_factor: tuple[float, float],
+    class_dict: dict | None = None,
+    save_path: Path | None = None,
+) -> AnnotationStore | Path:
+    """Converts output of TIAToolbox SemanticSegmentor engine to AnnotationStore.
+
+    Args:
+        patch_output (dict | zarr.Group):
+            A dictionary with "probabilities", "predictions", and "labels" keys.
+        scale_factor (tuple[float, float]):
+            The scale factor to use when loading the
+            annotations. All coordinates will be multiplied by this factor to allow
+            conversion of annotations saved at non-baseline resolution to baseline.
+            Should be model_mpp/slide_mpp.
+        class_dict (dict):
+            Optional dictionary mapping class indices to class names.
+        save_path (str or Path):
+            Optional Output directory to save the Annotation
+            Store results.
+
+    Returns:
+        (SQLiteStore or Path):
+            An SQLiteStore containing Annotations for each patch
+            or Path to file storing SQLiteStore containing Annotations
+            for each patch.
+
+    """
+    preds = patch_output["predictions"]
+    layer_list = np.unique(preds)
+    layer_list = np.delete(layer_list, np.where(layer_list == 0))
+    layer_info_dict = {}
+    count = 1
+
+    for type_class in layer_list:
+        layer = np.where(preds == type_class, 1, 0).astype("uint8")
+        contours, _ = cv2.findContours(
+            layer.astype("uint8"),
+            cv2.RETR_TREE,
+            cv2.CHAIN_APPROX_NONE,
+        )
+        for layer in contours:
+            coords = layer[:, 0, :]
+            layer_info_dict[count] = {
+                "contours": coords,
+                "type": class_dict[type_class],
+            }
+            count += 1
+
+    # return layer_info_dict
+
+    # if "coordinates" not in patch_output:
+    #     # we cant create annotations without coordinates
+    #     msg = "Patch output must contain coordinates."
+    #     raise ValueError(msg)
+    #
+    # # get relevant keys
+    # class_probs = get_zarr_array(patch_output.get("probabilities", []))
+    # preds = get_zarr_array(patch_output.get("predictions", []))
+    #
+    # patch_coords = np.array(patch_output.get("coordinates", []))
+    # if not np.all(np.array(scale_factor) == 1):
+    #     patch_coords = patch_coords * (np.tile(scale_factor, 2))  # to baseline mpp
+    # patch_coords = patch_coords.astype(float)
+    # labels = patch_output.get("labels", [])
+    # # get classes to consider
+    # if len(class_probs) == 0:
+    #     classes_predicted = np.unique(preds).tolist()
+    # else:
+    #     classes_predicted = range(len(class_probs[0]))
+    #
+    # if class_dict is None:
+    #     # if no class dict create a default one
+    #     if len(class_probs) == 0:
+    #         class_dict = {i: i for i in np.unique(np.append(preds, labels)).tolist()}
+    #     else:
+    #         class_dict = {i: i for i in range(len(class_probs[0]))}
+    #
+    # # find what keys we need to save
+    # keys = ["predictions"]
+    # keys = keys + [key for key in ["probabilities", "labels"] if key in patch_output]
+    #
+    # # put patch predictions into a store
+    # annotations = patch_predictions_as_annotations(
+    #     preds,
+    #     keys,
+    #     class_dict,
+    #     class_probs,
+    #     patch_coords,
+    #     classes_predicted,
+    #     labels,
+    # )
+    #
+    # store = SQLiteStore()
+    # _ = store.append_many(annotations, [str(i) for i in range(len(annotations))])
+    #
+    # # if a save director is provided, then dump store into a file
+    # if save_path:
+    #     # ensure parent directory exists
+    #     save_path.parent.absolute().mkdir(parents=True, exist_ok=True)
+    #     # ensure proper db extension
+    #     save_path = save_path.parent.absolute() / (save_path.stem + ".db")
+    #     store.dump(save_path)
+    #     return save_path
+    #
+    # return store
+
+
+
 def get_zarr_array(zarr_array: zarr.core.Array | np.ndarray | list) -> np.ndarray:
     """Converts a zarr array into a numpy array."""
     if isinstance(zarr_array, zarr.core.Array):
@@ -1236,13 +1346,13 @@ def get_zarr_array(zarr_array: zarr.core.Array | np.ndarray | list) -> np.ndarray:
     return np.array(zarr_array).astype(float)
 
 
-def dict_to_store(
+def dict_to_store_patch_predictions(
     patch_output: dict | zarr.group,
     scale_factor: tuple[float, float],
     class_dict: dict | None = None,
     save_path: Path | None = None,
 ) -> AnnotationStore | Path:
-    """Converts (and optionally saves) output of TIAToolbox engines as AnnotationStore.
+    """Converts output of TIAToolbox PatchPredictor engines to AnnotationStore.
 
     Args:
         patch_output (dict | zarr.Group):
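For orientation, here is a small standalone sketch of the contour-extraction loop that the active part of `dict_to_store_semantic_segmentor` performs on the `"predictions"` mask. The toy mask and class names below are illustrative; note that the committed function builds `layer_info_dict` but does not yet return it or convert it into an `AnnotationStore` (that part is still commented out).

import cv2
import numpy as np

# A toy 2-class prediction mask (0 = background), standing in for
# patch_output["predictions"] from the semantic segmentor.
preds = np.zeros((64, 64), dtype=np.uint8)
preds[8:24, 8:24] = 1
preds[40:56, 30:50] = 2
class_dict = {1: "tumour", 2: "stroma"}  # illustrative class names

layer_info_dict = {}
count = 1
for type_class in np.unique(preds):
    if type_class == 0:  # background is dropped, as in the committed code
        continue
    layer = np.where(preds == type_class, 1, 0).astype("uint8")
    contours, _ = cv2.findContours(layer, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    for contour in contours:
        layer_info_dict[count] = {
            "contours": contour[:, 0, :],  # (N, 2) boundary coordinates
            "type": class_dict[type_class],
        }
        count += 1

print({k: (v["type"], v["contours"].shape) for k, v in layer_info_dict.items()})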
